Skip to content

Commit

Permalink
fix(sqla-query): order by aggregations in Presto and Hive (apache#13739)
Browse files Browse the repository at this point in the history
  • Loading branch information
ktmud authored Apr 2, 2021
1 parent 7621010 commit 4789074
Show file tree
Hide file tree
Showing 10 changed files with 315 additions and 108 deletions.
130 changes: 88 additions & 42 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
import json
import logging
import re
from collections import defaultdict, OrderedDict
from contextlib import closing
from dataclasses import dataclass, field # pylint: disable=wrong-import-order
Expand Down Expand Up @@ -50,6 +51,7 @@
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom, TextClause
from sqlalchemy.sql.selectable import Alias, TableClause
from sqlalchemy.types import TypeEngine

from superset import app, db, is_feature_enabled, security_manager
Expand All @@ -70,7 +72,7 @@
from superset.sql_parse import ParsedQuery
from superset.typing import AdhocMetric, Metric, OrderBy, QueryObjectDict
from superset.utils import core as utils
from superset.utils.core import GenericDataType
from superset.utils.core import GenericDataType, remove_duplicates

config = app.config
metadata = Model.metadata # pylint: disable=no-member
Expand Down Expand Up @@ -465,7 +467,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False)
fetch_values_predicate = Column(String(1000))
owners = relationship(owner_class, secondary=sqlatable_user, backref="tables")
database = relationship(
database: Database = relationship(
"Database",
backref=backref("tables", cascade="all, delete-orphan"),
foreign_keys=[database_id],
Expand Down Expand Up @@ -507,22 +509,6 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
"MAX": sa.func.MAX,
}

def make_sqla_column_compatible(
    self, sqla_col: Column, label: Optional[str] = None
) -> Column:
    """Attach a label to a sqlalchemy column when the engine allows aliases.

    The label defaults to the column's own name when ``label`` is empty. It is
    passed through the engine spec's ``make_label_compatible`` before being
    applied. Engines whose spec has ``allows_alias_in_select`` unset get the
    column back untouched.

    :param sqla_col: sqlalchemy column instance
    :param label: alias/label that column is expected to have
    :return: the labelled column, or ``sqla_col`` itself if aliases in SELECT
        are not supported by the engine
    """
    expected_label = label or sqla_col.name
    engine_spec = self.database.db_engine_spec
    if not engine_spec.allows_alias_in_select:
        return sqla_col
    compatible_label = engine_spec.make_label_compatible(expected_label)
    return sqla_col.label(compatible_label)

def __repr__(self) -> str:
return self.name

Expand Down Expand Up @@ -708,11 +694,10 @@ def health_check_message(self) -> Optional[str]:
def data(self) -> Dict[str, Any]:
data_ = super().data
if self.type == "table":
grains = self.database.grains() or []
if grains:
grains = [(g.duration, g.name) for g in grains]
data_["granularity_sqla"] = utils.choicify(self.dttm_cols)
data_["time_grain_sqla"] = grains
data_["time_grain_sqla"] = [
(g.duration, g.name) for g in self.database.grains() or []
]
data_["main_dttm_col"] = self.main_dttm_col
data_["fetch_values_predicate"] = self.fetch_values_predicate
data_["template_params"] = self.template_params
Expand Down Expand Up @@ -800,15 +785,15 @@ def get_query_str(self, query_obj: QueryObjectDict) -> str:
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
return ";\n\n".join(all_queries) + ";"

def get_sqla_table(self) -> table:
def get_sqla_table(self) -> TableClause:
tbl = table(self.table_name)
if self.schema:
tbl.schema = self.schema
return tbl

def get_from_clause(
self, template_processor: Optional[BaseTemplateProcessor] = None
) -> Union[table, TextAsFrom]:
) -> Union[TableClause, Alias]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with its own subquery.
Expand Down Expand Up @@ -882,6 +867,51 @@ def adhoc_metric_to_sqla(

return self.make_sqla_column_compatible(sqla_metric, label)

def make_sqla_column_compatible(
    self, sqla_col: Column, label: Optional[str] = None
) -> Column:
    """Attach a label to a sqlalchemy column when the engine allows aliases.

    The label defaults to the column's own name when ``label`` is empty. It is
    passed through the engine spec's ``make_label_compatible`` before being
    applied. Engines whose spec has ``allows_alias_in_select`` unset get the
    column back untouched.

    :param sqla_col: sqlalchemy column instance
    :param label: alias/label that column is expected to have
    :return: the labelled column, or ``sqla_col`` itself if aliases in SELECT
        are not supported by the engine
    """
    expected_label = label or sqla_col.name
    engine_spec = self.database.db_engine_spec
    if not engine_spec.allows_alias_in_select:
        return sqla_col
    compatible_label = engine_spec.make_label_compatible(expected_label)
    return sqla_col.label(compatible_label)

def make_orderby_compatible(
    self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement]
) -> None:
    """
    Rename SELECT aliases that are referenced inside `ORDER BY` expressions.

    Some databases (e.g. Presto) cannot automatically pick the source column
    in `ORDER BY` when a `SELECT` alias has the same name as that source
    column. For such engines the conflicting alias is renamed in place by
    appending ``__``. Nothing is done when the engine spec sets
    ``allows_alias_to_source_column``.

    :param select_exprs: selected columns; labelled ones may be renamed in place
    :param orderby_exprs: expressions used in `ORDER BY`, only read
    """
    engine_spec = self.database.db_engine_spec
    if not engine_spec.allows_alias_to_source_column:
        # The final output columns will still use the original names, because
        # they are updated by `labels_expected` after querying.
        for selected in select_exprs:
            # Only labelled expressions carry an alias that could conflict.
            if not isinstance(selected, Label):
                continue
            # Match the alias as a whole word anywhere inside parentheses,
            # i.e. used as an argument of a function/aggregation call.
            alias_in_call = re.compile(
                f"\\(.*\\b{re.escape(selected.name)}\\b.*\\)", re.IGNORECASE
            )
            for ordering in orderby_exprs:
                if alias_in_call.search(str(ordering)):
                    selected.name = f"{selected.name}__"
                    break

def _get_sqla_row_level_filters(
self, template_processor: BaseTemplateProcessor
) -> List[str]:
Expand Down Expand Up @@ -995,9 +1025,8 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma

# To ensure correct handling of the ORDER BY labeling we need to reference the
# metric instance if defined in the SELECT clause.
metrics_exprs_by_label = {
m.name: m for m in metrics_exprs # pylint: disable=protected-access
}
metrics_exprs_by_label = {m.name: m for m in metrics_exprs}
metrics_exprs_by_expr = {str(m): m for m in metrics_exprs}

# Since orderby may use adhoc metrics, too; we need to process them first
orderby_exprs: List[ColumnElement] = []
Expand All @@ -1007,21 +1036,25 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
if utils.is_adhoc_metric(col):
# add adhoc sort by column to columns_by_name if not exists
col = self.adhoc_metric_to_sqla(col, columns_by_name)
# if the adhoc metric has been defined before
# use the existing instance.
col = metrics_exprs_by_expr.get(str(col), col)
need_groupby = True
elif col in columns_by_name:
col = columns_by_name[col].get_sqla_col()
elif col in metrics_exprs_by_label:
col = metrics_exprs_by_label[col]
need_groupby = True
elif col in metrics_by_name:
col = metrics_by_name[col].get_sqla_col()
need_groupby = True
elif col in metrics_exprs_by_label:
col = metrics_exprs_by_label[col]

if isinstance(col, ColumnElement):
orderby_exprs.append(col)
else:
# Could not convert a column reference to valid ColumnElement
raise QueryObjectValidationError(
_("Unknown column used in orderby: %(col)", col=orig_col)
_("Unknown column used in orderby: %(col)s", col=orig_col)
)

select_exprs: List[Union[Column, Label]] = []
Expand Down Expand Up @@ -1093,11 +1126,21 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
dttm_col.get_time_filter(from_dttm, to_dttm, time_range_endpoints)
)

select_exprs += metrics_exprs
labels_expected = [c.name for c in select_exprs]
select_exprs = db_engine_spec.make_select_compatible(
groupby_exprs_with_timestamp.values(), select_exprs
# Always remove duplicates by column name, as sometimes `metrics_exprs`
# can have the same name as a groupby column (e.g. when users use
# raw columns as custom SQL adhoc metric).
select_exprs = remove_duplicates(
select_exprs + metrics_exprs, key=lambda x: x.name
)

# Expected output columns
labels_expected = [c.name for c in select_exprs]

# Order by columns are "hidden" columns; some databases require them to
# always be present in SELECT if an aggregation function is used
if not db_engine_spec.allows_hidden_ordeby_agg:
select_exprs = remove_duplicates(select_exprs + orderby_exprs)

qry = sa.select(select_exprs)

tbl = self.get_from_clause(template_processor)
Expand Down Expand Up @@ -1213,12 +1256,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
qry = qry.where(and_(*where_clause_and))
qry = qry.having(and_(*having_clause_and))

self.make_orderby_compatible(select_exprs, orderby_exprs)

for col, (orig_col, ascending) in zip(orderby_exprs, orderby):
if (
db_engine_spec.allows_alias_in_orderby
and col.name in metrics_exprs_by_label
):
col = Label(col.name, metrics_exprs_by_label[col.name])
if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label):
# if engine does not allow using SELECT alias in ORDER BY
# revert to the underlying column
col = col.element
direction = asc if ascending else desc
qry = qry.order_by(direction(col))

Expand Down Expand Up @@ -1315,17 +1359,19 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
result.df, dimensions, groupby_exprs_sans_timestamp
)
qry = qry.where(top_groups)

qry = qry.select_from(tbl)

if is_rowcount:
if not db_engine_spec.allows_subqueries:
raise QueryObjectValidationError(
_("Database does not support subqueries")
)
label = "rowcount"
col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label)
qry = select([col]).select_from(qry.select_from(tbl).alias("rowcount_qry"))
qry = select([col]).select_from(qry.alias("rowcount_qry"))
labels_expected = [label]
else:
qry = qry.select_from(tbl)

return SqlaQuery(
extra_cache_keys=extra_cache_keys,
labels_expected=labels_expected,
Expand Down
38 changes: 22 additions & 16 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom
from sqlalchemy.types import String, TypeEngine, UnicodeText

from superset import app, security_manager, sql_parse
Expand Down Expand Up @@ -137,7 +137,18 @@ class LimitMethod: # pylint: disable=too-few-public-methods


class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""Abstract class for database engine specific configurations"""
"""Abstract class for database engine specific configurations
Attributes:
allows_alias_to_source_column: Whether the engine is able to pick the
source column for aggregation clauses
used in ORDER BY when a column in SELECT
has an alias that is the same as a source
column.
allows_hidden_ordeby_agg: Whether the engine allows ORDER BY to
directly use aggregation clauses, without
having to add the same aggregation in SELECT.
"""

engine = "base" # str as defined in sqlalchemy.engine.engine
engine_aliases: Optional[Tuple[str]] = None
Expand Down Expand Up @@ -241,6 +252,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
allows_alias_in_select = True
allows_alias_in_orderby = True
allows_sql_comments = True

# Whether ORDER BY clause can use aliases created in SELECT
# that are the same as a source column
allows_alias_to_source_column = True

# Whether ORDER BY may use aggregation expressions that are not in SELECT.
# If True, such expressions do not have to be added to SELECT.
allows_hidden_ordeby_agg = True

force_column_alias_quotes = False
arraysize = 0
max_column_name_length = 0
Expand Down Expand Up @@ -441,20 +461,6 @@ def get_time_grain_expressions(cls) -> Dict[Optional[str], str]:
)
)

@classmethod
def make_select_compatible(
    cls, groupby_exprs: Dict[str, ColumnElement], select_exprs: List[ColumnElement]
) -> List[ColumnElement]:
    """
    Hook for adjusting the SELECT list for engines that return group-by
    fields on their own but do not allow them in the select list.

    This default implementation does not use ``groupby_exprs`` and hands
    ``select_exprs`` back unchanged.

    :param groupby_exprs: mapping between column name and column object
    :param select_exprs: all columns in the select clause
    :return: columns to be included in the final select clause
    """
    return select_exprs

@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
Expand Down
3 changes: 3 additions & 0 deletions superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class HiveEngineSpec(PrestoEngineSpec):
engine = "hive"
engine_name = "Apache Hive"
max_column_name_length = 767
allows_alias_to_source_column = True
allows_hidden_ordeby_agg = False

# pylint: disable=line-too-long
_time_grain_expressions = {
None: "{col}",
Expand Down
10 changes: 2 additions & 8 deletions superset/db_engine_specs/pinot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List, Optional
from typing import Dict, Optional

from sqlalchemy.sql.expression import ColumnClause, ColumnElement
from sqlalchemy.sql.expression import ColumnClause

from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression

Expand Down Expand Up @@ -112,9 +112,3 @@ def get_timestamp_expr(
time_expr = f"DATETIMECONVERT({{col}}, '{tf}', '{tf}', '{granularity}')"

return TimestampExpression(time_expr, col)

@classmethod
def make_select_compatible(
    cls, groupby_exprs: Dict[str, ColumnElement], select_exprs: List[ColumnElement]
) -> List[ColumnElement]:
    """
    Return the SELECT columns unchanged; ``groupby_exprs`` is not used.

    :param groupby_exprs: mapping between column name and column object
    :param select_exprs: all columns in the select clause
    :return: ``select_exprs`` as given
    """
    return select_exprs
1 change: 1 addition & 0 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def get_children(column: Dict[str, str]) -> List[Dict[str, str]]:
class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods
engine = "presto"
engine_name = "Presto"
allows_alias_to_source_column = False

_time_grain_expressions = {
None: "{col}",
Expand Down
16 changes: 16 additions & 0 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,6 +1631,22 @@ def find_duplicates(items: Iterable[InputType]) -> List[InputType]:
return [item for item, count in collections.Counter(items).items() if count > 1]


def remove_duplicates(
    items: Iterable[InputType], key: Optional[Callable[[InputType], Any]] = None
) -> List[InputType]:
    """Remove duplicate items from an iterable, keeping first occurrences.

    :param items: iterable to de-duplicate
    :param key: optional function computing the value used to compare items;
        when omitted the items themselves are compared
    :return: list of the unique items in their original order
    """
    if key:
        seen_keys = set()
        unique_items = []
        for candidate in items:
            marker = key(candidate)
            if marker in seen_keys:
                continue
            seen_keys.add(marker)
            unique_items.append(candidate)
        return unique_items
    # dict keys are unique and keep insertion order
    return list(dict.fromkeys(items))


def normalize_dttm_col(
df: pd.DataFrame,
timestamp_format: Optional[str],
Expand Down
Loading

0 comments on commit 4789074

Please sign in to comment.