Skip to content

Commit

Permalink
fix: handle query exceptions gracefully (apache#10548)
Browse files (browse the repository at this point in the history)
* fix: handle query exceptions gracefully

* add more recasts

* add test

* disable test for presto

* switch to SQLA error
  • Loading branch information
villebro authored Aug 7, 2020
1 parent ea0db0d commit 08358d6
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 41 deletions.
6 changes: 5 additions & 1 deletion superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from superset.common.query_object import QueryObject
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import QueryObjectValidationError
from superset.stats_logger import BaseStatsLogger
from superset.utils import core as utils
from superset.utils.core import DTTM_ALIAS
Expand Down Expand Up @@ -244,10 +245,13 @@ def get_df_payload( # pylint: disable=too-many-locals,too-many-statements
if not self.force:
stats_logger.incr("loaded_from_source_without_force")
is_loaded = True
except QueryObjectValidationError as ex:
error_message = str(ex)
status = utils.QueryStatus.FAILED
except Exception as ex: # pylint: disable=broad-except
logger.exception(ex)
if not error_message:
error_message = "{}".format(ex)
error_message = str(ex)
status = utils.QueryStatus.FAILED
stacktrace = utils.get_stacktrace()

Expand Down
85 changes: 63 additions & 22 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from flask import escape, Markup
from flask_appbuilder import Model
from flask_babel import lazy_gettext as _
from jinja2.exceptions import TemplateError
from sqlalchemy import (
and_,
asc,
Expand All @@ -40,7 +41,7 @@
Table,
Text,
)
from sqlalchemy.exc import CompileError
from sqlalchemy.exc import CompileError, SQLAlchemyError
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.schema import UniqueConstraint
Expand All @@ -51,7 +52,7 @@
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.constants import NULL_STRING
from superset.db_engine_specs.base import TimestampExpression
from superset.exceptions import DatabaseNotFound
from superset.exceptions import DatabaseNotFound, QueryObjectValidationError
from superset.jinja_context import (
BaseTemplateProcessor,
ExtraCache,
Expand Down Expand Up @@ -634,7 +635,15 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:

if self.fetch_values_predicate:
tp = self.get_template_processor()
qry = qry.where(text(tp.process_template(self.fetch_values_predicate)))
try:
qry = qry.where(text(tp.process_template(self.fetch_values_predicate)))
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in fetch values predicate: %(msg)s",
msg=ex.message,
)
)

engine = self.database.get_sqla_engine()
sql = "{}".format(qry.compile(engine, compile_kwargs={"literal_binds": True}))
Expand Down Expand Up @@ -684,7 +693,16 @@ def get_from_clause(
if self.sql:
from_sql = self.sql
if template_processor:
from_sql = template_processor.process_template(from_sql)
try:
from_sql = template_processor.process_template(from_sql)
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in FROM clause: %(msg)s",
msg=ex.message,
)
)

from_sql = sqlparse.format(from_sql, strip_comments=True)
return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
return self.get_sqla_table()
Expand Down Expand Up @@ -730,10 +748,15 @@ def _get_sqla_row_level_filters(
:returns: A list of SQL clauses to be ANDed together.
:rtype: List[str]
"""
return [
text("({})".format(template_processor.process_template(f.clause)))
for f in security_manager.get_rls_filters(self)
]
try:
return [
text("({})".format(template_processor.process_template(f.clause)))
for f in security_manager.get_rls_filters(self)
]
except TemplateError as ex:
raise QueryObjectValidationError(
_("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,)
)

def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
self,
Expand Down Expand Up @@ -791,7 +814,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}

if not granularity and is_timeseries:
raise Exception(
raise QueryObjectValidationError(
_(
"Datetime column not provided as part table configuration "
"and is required by this type of chart"
Expand All @@ -802,7 +825,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
and not columns
and (is_sip_38 or (not is_sip_38 and not groupby))
):
raise Exception(_("Empty query?"))
raise QueryObjectValidationError(_("Empty query?"))
metrics_exprs: List[ColumnElement] = []
for metric in metrics:
if utils.is_adhoc_metric(metric):
Expand All @@ -811,7 +834,9 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
elif isinstance(metric, str) and metric in metrics_by_name:
metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
else:
raise Exception(_("Metric '%(metric)s' does not exist", metric=metric))
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=metric)
)
if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
Expand Down Expand Up @@ -958,19 +983,35 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
!= None
)
else:
raise Exception(
raise QueryObjectValidationError(
_("Invalid filter operation type: %(op)s", op=op)
)
if config["ENABLE_ROW_LEVEL_SECURITY"]:
where_clause_and += self._get_sqla_row_level_filters(template_processor)
if extras:
where = extras.get("where")
if where:
where = template_processor.process_template(where)
try:
where = template_processor.process_template(where)
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in WHERE clause: %(msg)s",
msg=ex.message,
)
)
where_clause_and += [sa.text("({})".format(where))]
having = extras.get("having")
if having:
having = template_processor.process_template(having)
try:
having = template_processor.process_template(having)
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in HAVING clause: %(msg)s",
msg=ex.message,
)
)
having_clause_and += [sa.text("({})".format(having))]
if granularity:
qry = qry.where(and_(*(time_filters + where_clause_and)))
Expand Down Expand Up @@ -1117,7 +1158,7 @@ def _get_timeseries_orderby(
):
ob = metrics_by_name[timeseries_limit_metric].get_sqla_col()
else:
raise Exception(
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric)
)

Expand Down Expand Up @@ -1159,7 +1200,7 @@ def mutator(df: pd.DataFrame) -> None:
labels_expected = query_str_ext.labels_expected
if df is not None and not df.empty:
if len(df.columns) != len(labels_expected):
raise Exception(
raise QueryObjectValidationError(
f"For {sql}, df.columns: {df.columns}"
f" differs from {labels_expected}"
)
Expand Down Expand Up @@ -1193,13 +1234,13 @@ def fetch_metadata(self, commit: bool = True) -> None:
"""Fetches the metadata for the table and merges it in"""
try:
table_ = self.get_sqla_table_object()
except Exception as ex:
logger.exception(ex)
raise Exception(
except SQLAlchemyError:
raise QueryObjectValidationError(
_(
"Table [{}] doesn't seem to exist in the specified database, "
"couldn't fetch column information"
).format(self.table_name)
"Table %(table)s doesn't seem to exist in the specified database, "
"couldn't fetch column information",
table=self.table_name,
)
)

metrics = []
Expand Down
7 changes: 4 additions & 3 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from flask_appbuilder.security.decorators import has_access, has_access_api
from flask_appbuilder.security.sqla import models as ab_models
from flask_babel import gettext as __, lazy_gettext as _
from jinja2.exceptions import TemplateError
from sqlalchemy import and_, or_, select
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import (
Expand Down Expand Up @@ -535,7 +536,7 @@ def explore_json(

return self.generate_json(viz_obj, response_type)
except SupersetException as ex:
return json_error_response(utils.error_msg_from_exception(ex))
return json_error_response(utils.error_msg_from_exception(ex), 400)

@event_logger.log_this
@has_access
Expand Down Expand Up @@ -2300,10 +2301,10 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals
rendered_query = template_processor.process_template(
query.sql, **template_params
)
except Exception as ex: # pylint: disable=broad-except
except TemplateError as ex:
error_msg = utils.error_msg_from_exception(ex)
return json_error_response(
f"Query {query_id}: Template rendering failed: {error_msg}"
f"Query {query_id}: Template syntax error: {error_msg}"
)

# Limit is not applied to the CTA queries if SQLLAB_CTAS_NO_LIMIT flag is set
Expand Down
50 changes: 35 additions & 15 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,17 @@ def query_obj(self) -> QueryObjectDict:
# default order direction
order_desc = form_data.get("order_desc", True)

since, until = utils.get_since_until(
relative_start=relative_start,
relative_end=relative_end,
time_range=form_data.get("time_range"),
since=form_data.get("since"),
until=form_data.get("until"),
)
try:
since, until = utils.get_since_until(
relative_start=relative_start,
relative_end=relative_end,
time_range=form_data.get("time_range"),
since=form_data.get("since"),
until=form_data.get("until"),
)
except ValueError as ex:
raise QueryObjectValidationError(str(ex))

time_shift = form_data.get("time_shift", "")
self.time_shift = utils.parse_past_timedelta(time_shift)
from_dttm = None if since is None else (since - self.time_shift)
Expand Down Expand Up @@ -475,6 +479,16 @@ def get_df_payload(
if not self.force:
stats_logger.incr("loaded_from_source_without_force")
is_loaded = True
except QueryObjectValidationError as ex:
error = dataclasses.asdict(
SupersetError(
message=str(ex),
level=ErrorLevel.ERROR,
error_type=SupersetErrorType.VIZ_GET_DF_ERROR,
)
)
self.errors.append(error)
self.status = utils.QueryStatus.FAILED
except Exception as ex:
logger.exception(ex)

Expand Down Expand Up @@ -889,13 +903,16 @@ def get_data(self, df: pd.DataFrame) -> VizData:
values[str(v / 10 ** 9)] = obj.get(metric)
data[metric] = values

start, end = utils.get_since_until(
relative_start=relative_start,
relative_end=relative_end,
time_range=form_data.get("time_range"),
since=form_data.get("since"),
until=form_data.get("until"),
)
try:
start, end = utils.get_since_until(
relative_start=relative_start,
relative_end=relative_end,
time_range=form_data.get("time_range"),
since=form_data.get("since"),
until=form_data.get("until"),
)
except ValueError as ex:
raise QueryObjectValidationError(str(ex))
if not start or not end:
raise QueryObjectValidationError(
"Please provide both time bounds (Since and Until)"
Expand Down Expand Up @@ -1288,7 +1305,10 @@ def run_extra_queries(self) -> None:

for option in time_compare:
query_object = self.query_obj()
delta = utils.parse_past_timedelta(option)
try:
delta = utils.parse_past_timedelta(option)
except ValueError as ex:
raise QueryObjectValidationError(str(ex))
query_object["inner_from_dttm"] = query_object["from_dttm"]
query_object["inner_to_dttm"] = query_object["to_dttm"]

Expand Down
25 changes: 25 additions & 0 deletions tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
# isort:skip_file
from typing import Any, Dict, NamedTuple, List, Tuple, Union
from unittest.mock import patch
import pytest

import tests.test_app
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.exceptions import QueryObjectValidationError
from superset.models.core import Database
from superset.utils.core import DbColumnType, get_example_database, FilterOperator

Expand Down Expand Up @@ -170,3 +172,26 @@ class FilterTestCase(NamedTuple):
sqla_query = table.get_sqla_query(**query_obj)
sql = table.database.compile_sqla_query(sqla_query.sqla_query)
self.assertIn(filter_.expected, sql)

def test_incorrect_jinja_syntax_raises_correct_exception(self):
    """Check that malformed Jinja in a table's SQL raises a validation error.

    Builds a ``SqlaTable`` whose ``sql`` contains a Jinja expression with
    incorrect syntax and asserts that ``get_sqla_query`` raises
    ``QueryObjectValidationError``. The assertion is not run when the
    example database backend is presto.
    """
    query_obj = dict(
        granularity=None,
        from_dttm=None,
        to_dttm=None,
        groupby=["user"],
        metrics=[],
        is_timeseries=False,
        filter=[],
        extras={},
    )

    # Table defined by SQL that embeds a syntactically invalid Jinja expression.
    broken_sql = "SELECT '{{ abcd xyz + 1 ASDF }}' as user"
    table = SqlaTable(
        table_name="test_table", sql=broken_sql, database=get_example_database(),
    )

    # TODO(villebro): make it work with presto
    if get_example_database().backend == "presto":
        return

    with pytest.raises(QueryObjectValidationError):
        table.get_sqla_query(**query_obj)

0 comments on commit 08358d6

Please sign in to comment.