Skip to content

Commit

Permalink
fix(mssql): apply limit and set alias for functions (apache#9644)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar authored Apr 27, 2020
1 parent 5e4c291 commit 516bdf6
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 3 deletions.
14 changes: 13 additions & 1 deletion superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import re
from datetime import datetime
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, TYPE_CHECKING

from sqlalchemy.types import String, TypeEngine, UnicodeText

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.sql_parse import ParsedQuery

if TYPE_CHECKING:
from superset.models.core import Database # pylint: disable=unused-import

logger = logging.getLogger(__name__)


class MssqlEngineSpec(BaseEngineSpec):
Expand Down Expand Up @@ -76,3 +83,8 @@ def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
if regex.match(type_):
return sqla_type
return None

@classmethod
def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
new_sql = ParsedQuery(sql).set_alias()
return super().apply_limit_to_sql(new_sql, limit, database)
45 changes: 44 additions & 1 deletion superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
from typing import List, Optional, Set

import sqlparse
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
from sqlparse.sql import (
Function,
Identifier,
IdentifierList,
remove_quotes,
Token,
TokenList,
)
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt

Expand Down Expand Up @@ -247,3 +254,39 @@ def set_or_update_query_limit(self, new_limit: int) -> str:
for i in statement.tokens:
str_res += str(i.value)
return str_res

def set_alias(self) -> str:
"""
Returns a new query string where all functions have alias.
This is particularly necessary for MSSQL engines.
:return: String with new aliased SQL query
"""
new_sql = ""
changed_counter = 1
for token in self._parsed[0].tokens:
# Identifier list (list of columns)
if isinstance(token, IdentifierList) and token.ttype is None:
for i, identifier in enumerate(token.get_identifiers()):
# Functions are anonymous on MSSQL
if isinstance(identifier, Function) and not identifier.has_alias():
identifier.value = (
f"{identifier.value} AS"
f" {identifier.get_real_name()}_{changed_counter}"
)
changed_counter += 1
new_sql += str(identifier.value)
# If not last identifier
if i != len(list(token.get_identifiers())) - 1:
new_sql += ", "
# Just a lonely function?
elif isinstance(token, Function) and token.ttype is None:
if not token.has_alias():
token.value = (
f"{token.value} AS {token.get_real_name()}_{changed_counter}"
)
new_sql += str(token.value)
# Nothing to change, assemble what we have
else:
new_sql += str(token.value)
return new_sql
64 changes: 63 additions & 1 deletion tests/db_engine_specs/mssql_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
# specific language governing permissions and limitations
# under the License.
import unittest.mock as mock
from typing import Optional

from sqlalchemy import column, table
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
from sqlalchemy.sql import select
from sqlalchemy.sql import select, Select
from sqlalchemy.types import String, UnicodeText

from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.extensions import db
from superset.models.core import Database
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase


Expand Down Expand Up @@ -94,6 +97,65 @@ def test_convert_dttm(self):
for actual, expected in test_cases:
self.assertEqual(actual, expected)

def test_apply_limit(self):
def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str:
return str(
qry.compile(
dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
)
)

database = Database(
database_name="mssql_test",
sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb",
)
db.session.add(database)
db.session.commit()

with mock.patch.object(database, "compile_sqla_query", new=compile_sqla_query):
test_sql = "SELECT COUNT(*) FROM FOO_TABLE"

limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)

expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)

test_sql = "SELECT COUNT(*), SUM(id) FROM FOO_TABLE"
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)

expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) "
"AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)

test_sql = "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1"
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)

expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1, "
"FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)"
" AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)

test_sql = "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE"
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)"
" AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)

db.session.delete(database)
db.session.commit()

@mock.patch.object(
MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"
)
Expand Down

0 comments on commit 516bdf6

Please sign in to comment.