Skip to content

Commit

Permalink
fix(presto/trino): Ensure get_table_names only returns real tables (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored Nov 9, 2022
1 parent 53ed8f2 commit 9f7bd1e
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 109 deletions.
1 change: 1 addition & 0 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ assists people when migrating to a new version.
- [21002](https://github.com/apache/superset/pull/21002): Support Python 3.10 and bump pandas 1.4 and pyarrow 6.
- [21163](https://github.com/apache/superset/pull/21163): When the `GENERIC_CHART_AXES` feature flag is set to `True`, the Time Grain control will move below the X-Axis control.
- [21284](https://github.com/apache/superset/pull/21284): The non-functional `MAX_TABLE_NAMES` config key has been removed.
- [21794](https://github.com/apache/superset/pull/21794): Deprecates the undocumented `PRESTO_SPLIT_VIEWS_FROM_TABLES` feature flag. Now for Presto, like other engines, only physical tables are treated as tables.

### Breaking Changes

Expand Down
2 changes: 0 additions & 2 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ geopy==2.2.0
# via apache-superset
graphlib-backport==1.0.3
# via apache-superset
greenlet==1.1.2
# via sqlalchemy
gunicorn==20.1.0
# via apache-superset
hashids==1.3.1
Expand Down
2 changes: 2 additions & 0 deletions requirements/docker.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# -r requirements/docker.in
gevent==21.8.0
# via -r requirements/docker.in
greenlet==1.1.3.post0
# via gevent
psycopg2-binary==2.9.1
# via apache-superset
zope-event==4.5.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/testing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ rsa==4.7.2
# via google-auth
statsd==3.3.0
# via -r requirements/testing.in
trino==0.315.0
trino==0.319.0
# via apache-superset
typing-inspect==0.7.1
# via libcst
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_git_sha() -> str:
"pinot": ["pinotdb>=0.3.3, <0.4"],
"postgres": ["psycopg2-binary==2.9.1"],
"presto": ["pyhive[presto]>=0.6.5"],
"trino": ["trino>=0.313.0"],
"trino": ["trino>=0.319.0"],
"prophet": ["prophet>=1.0.1, <1.1", "pystan<3.0"],
"redshift": ["sqlalchemy-redshift>=0.8.1, < 0.9"],
"rockset": ["rockset>=0.8.10, <0.9"],
Expand Down
4 changes: 2 additions & 2 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, List, Optional
from typing import Any, cast, Dict, List, Optional
from zipfile import is_zipfile, ZipFile

from flask import request, Response, send_file
Expand Down Expand Up @@ -611,7 +611,7 @@ def table_extra_metadata(
self.incr_stats("init", self.table_metadata.__name__)

parsed_schema = parse_js_uri_path_item(schema_name, eval_undefined=True)
table_name = parse_js_uri_path_item(table_name) # type: ignore
table_name = cast(str, parse_js_uri_path_item(table_name))
payload = database.db_engine_spec.extra_table_metadata(
database, table_name, parsed_schema
)
Expand Down
28 changes: 18 additions & 10 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,13 +1018,17 @@ def get_table_names( # pylint: disable=unused-argument
schema: Optional[str],
) -> List[str]:
"""
Get all tables from schema
Get all the real table names within the specified schema.
:param database: The database to get info
:param inspector: SqlAlchemy inspector
:param schema: Schema to inspect. If omitted, uses default schema for database
:return: All tables in schema
Per the SQLAlchemy definition if the schema is omitted the database’s default
schema is used, however some dialects infer the request as schema agnostic.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
:param schema: The schema to inspect
:returns: The physical table names
"""

try:
tables = inspector.get_table_names(schema)
except Exception as ex:
Expand All @@ -1042,13 +1046,17 @@ def get_view_names( # pylint: disable=unused-argument
schema: Optional[str],
) -> List[str]:
"""
Get all views from schema
Get all the view names within the specified schema.
:param database: The database to get info
:param inspector: SqlAlchemy inspector
:param schema: Schema name. If omitted, uses default schema for database
:return: All views in schema
Per the SQLAlchemy definition if the schema is omitted the database’s default
schema is used, however some dialects infer the request as schema agnostic.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
:param schema: The schema to inspect
:returns: The view names
"""

try:
views = inspector.get_view_names(schema)
except Exception as ex:
Expand Down
82 changes: 60 additions & 22 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@

import logging
import re
import textwrap
import time
from abc import ABCMeta
from collections import defaultdict, deque
from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from textwrap import dedent
from typing import Any, cast, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, Union
from urllib import parse

Expand Down Expand Up @@ -392,46 +392,84 @@ def update_impersonation_config(

@classmethod
def get_table_names(
cls, database: Database, inspector: Inspector, schema: Optional[str]
cls,
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
tables = super().get_table_names(database, inspector, schema)
if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
return tables
"""
Get all the real table names within the specified schema.
Per the SQLAlchemy definition if the schema is omitted the database’s default
schema is used, however some dialects infer the request as schema agnostic.
views = set(cls.get_view_names(database, inspector, schema))
actual_tables = set(tables) - views
return list(actual_tables)
Note that PyHive's Hive and Presto SQLAlchemy dialects do not adhere to the
specification, as their `get_table_names` method returns both real tables
and views. Furthermore, the dialects wrongfully infer the request as schema
agnostic when the schema is omitted.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
:param schema: The schema to inspect
:returns: The physical table names
"""

return sorted(
list(
set(super().get_table_names(database, inspector, schema))
- set(cls.get_view_names(database, inspector, schema))
)
)

@classmethod
def get_view_names(
cls, database: Database, inspector: Inspector, schema: Optional[str]
cls,
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
"""Returns an empty list
"""
Get all the view names within the specified schema.
get_table_names() function returns all table names and view names,
and get_view_names() is not implemented in sqlalchemy_presto.py
https://github.com/dropbox/PyHive/blob/e25fc8440a0686bbb7a5db5de7cb1a77bdb4167a/pyhive/sqlalchemy_presto.py
Per the SQLAlchemy definition if the schema is omitted the database’s default
schema is used, however some dialects infer the request as schema agnostic.
Note that PyHive's Hive and Presto SQLAlchemy dialects do not implement the
`get_view_names` method. To ensure consistency with the `get_table_names` method
the request is deemed schema agnostic when the schema is omitted.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
:param schema: The schema to inspect
:returns: The view names
"""
if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
return []

if schema:
sql = (
"SELECT table_name FROM information_schema.views "
"WHERE table_schema=%(schema)s"
)
sql = dedent(
"""
SELECT table_name FROM information_schema.tables
WHERE table_schema = %(schema)s
AND table_type = 'VIEW'
"""
).strip()
params = {"schema": schema}
else:
sql = "SELECT table_name FROM information_schema.views"
sql = dedent(
"""
SELECT table_name FROM information_schema.tables
WHERE table_type = 'VIEW'
"""
).strip()
params = {}

engine = cls.get_engine(database, schema=schema)

with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()

return [row[0] for row in results]
return sorted([row[0] for row in results])

@classmethod
def _create_column_info(
Expand Down Expand Up @@ -1087,7 +1125,7 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals
else f"SHOW PARTITIONS FROM {table_name}"
)

sql = textwrap.dedent(
sql = dedent(
f"""\
{partition_select_clause}
{where_clause}
Expand Down
4 changes: 0 additions & 4 deletions tests/integration_tests/db_engine_specs/hive_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,6 @@ def test_hive_error_msg():
)


def test_hive_get_view_names_return_empty_list(): # pylint: disable=invalid-name
assert HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY) == []


def test_convert_dttm():
dttm = datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")
assert HiveEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"
Expand Down
100 changes: 33 additions & 67 deletions tests/integration_tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from collections import namedtuple
from textwrap import dedent
from unittest import mock, skipUnless

import pandas as pd
Expand All @@ -33,49 +34,47 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_datatype_presto(self):
self.assertEqual("STRING", PrestoEngineSpec.get_datatype("string"))

def test_presto_get_view_names_return_empty_list(
self,
): # pylint: disable=invalid-name
self.assertEqual(
[], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
)

@mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
def test_get_view_names(self, mock_is_feature_enabled):
mock_is_feature_enabled.return_value = True
mock_execute = mock.MagicMock()
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
def test_get_view_names_with_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
mock_fetchall
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
schema = "schema"
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), schema)
mock_execute.assert_called_once_with(
"SELECT table_name FROM information_schema.views", {}
dedent(
"""
SELECT table_name FROM information_schema.tables
WHERE table_schema = %(schema)s
AND table_type = 'VIEW'
"""
).strip(),
{"schema": schema},
)
assert result == ["a", "d"]

@mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
def test_get_view_names_with_schema(self, mock_is_feature_enabled):
mock_is_feature_enabled.return_value = True
mock_execute = mock.MagicMock()
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
def test_get_view_names_without_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
mock_fetchall
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
schema = "schema"
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), schema)
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
mock_execute.assert_called_once_with(
"SELECT table_name FROM information_schema.views "
"WHERE table_schema=%(schema)s",
{"schema": schema},
dedent(
"""
SELECT table_name FROM information_schema.tables
WHERE table_type = 'VIEW'
"""
).strip(),
{},
)
assert result == ["a", "d"]

Expand Down Expand Up @@ -663,50 +662,17 @@ def test_get_sqla_column_type(self):
sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
assert sqla_type is None

@mock.patch(
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
)
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
def test_get_table_names_no_split_views_from_tables(
self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
):
mock_get_view_names.return_value = ["view1", "view2"]
table_names = ["table1", "table2", "view1", "view2"]
mock_get_table_names.return_value = table_names
mock_is_feature_enabled.return_value = False
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert tables == table_names

@mock.patch(
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
)
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
def test_get_table_names_split_views_from_tables(
self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
def test_get_table_names(
self,
mock_get_view_names,
mock_get_table_names,
):
mock_get_view_names.return_value = ["view1", "view2"]
table_names = ["table1", "table2", "view1", "view2"]
mock_get_table_names.return_value = table_names
mock_is_feature_enabled.return_value = True
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert sorted(tables) == sorted(table_names)

@mock.patch(
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
)
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
def test_get_table_names_split_views_from_tables_no_tables(
self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
):
mock_get_view_names.return_value = []
table_names = []
mock_get_table_names.return_value = table_names
mock_is_feature_enabled.return_value = True
mock_get_table_names.return_value = ["table1", "table2", "view1", "view2"]
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert tables == []
assert tables == ["table1", "table2"]

def test_get_full_name(self):
names = [
Expand Down

0 comments on commit 9f7bd1e

Please sign in to comment.