Skip to content

Commit

Permalink
[sql] Correct SQL parameter formatting (apache#5178)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored Jul 21, 2018
1 parent 6e7b587 commit 7fcc2af
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ ignored-modules=numpy,pandas,alembic.op,sqlalchemy,alembic.context,flask_appbuil
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session
ignored-classes=contextlib.closing,optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
Expand Down
9 changes: 1 addition & 8 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from flask_appbuilder import Model
from flask_babel import lazy_gettext as _
import pandas as pd
import six
import sqlalchemy as sa
from sqlalchemy import (
and_, asc, Boolean, Column, DateTime, desc, ForeignKey, Integer, or_,
Expand Down Expand Up @@ -427,14 +426,8 @@ def get_template_processor(self, **kwargs):
table=self, database=self.database, **kwargs)

def get_query_str(self, query_obj):
engine = self.database.get_sqla_engine()
qry = self.get_sqla_query(**query_obj)
sql = six.text_type(
qry.compile(
engine,
compile_kwargs={'literal_binds': True},
),
)
sql = self.database.compile_sqla_query(qry)
logging.info(sql)
sql = sqlparse.format(sql, reindent=True)
if query_obj['is_prequery']:
Expand Down
12 changes: 8 additions & 4 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ class BaseEngineSpec(object):
"""Abstract class for database engine specific configurations"""

engine = 'base' # str as defined in sqlalchemy.engine.engine
cursor_execute_kwargs = {}
time_grains = tuple()
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
Expand Down Expand Up @@ -331,6 +330,10 @@ def get_normalized_column_names(cls, cursor_description):
def normalize_column_name(column_name):
return column_name

@staticmethod
def execute(cursor, query, async=False):
    """Run ``query`` on ``cursor``.

    :param cursor: Object exposing an ``execute`` method
    :param query: Statement passed to ``cursor.execute`` unchanged
    :param async: Accepted but not used by this implementation
    """
    # NOTE(review): ``async`` is a reserved keyword from Python 3.7 onwards,
    # so this parameter name will need to change before the code can run
    # there -- confirm the supported interpreter versions.
    cursor.execute(query)


class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
Expand Down Expand Up @@ -558,7 +561,6 @@ def get_table_names(cls, schema, inspector):

class MySQLEngineSpec(BaseEngineSpec):
engine = 'mysql'
cursor_execute_kwargs = {'args': {}}
time_grains = (
Grain('Time Column', _('Time Column'), '{col}', None),
Grain('second', _('second'), 'DATE_ADD(DATE({col}), '
Expand Down Expand Up @@ -639,7 +641,6 @@ def extract_error_message(cls, e):

class PrestoEngineSpec(BaseEngineSpec):
engine = 'presto'
cursor_execute_kwargs = {'parameters': None}

time_grains = (
Grain('Time Column', _('Time Column'), '{col}', None),
Expand Down Expand Up @@ -938,7 +939,6 @@ class HiveEngineSpec(PrestoEngineSpec):
"""Reuses PrestoEngineSpec functionality."""

engine = 'hive'
cursor_execute_kwargs = {'async': True}

# Scoping regex at class level to avoid recompiling
# 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
Expand Down Expand Up @@ -1230,6 +1230,10 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username):
configuration['hive.server2.proxy.user'] = username
return configuration

@staticmethod
def execute(cursor, query, async=False):
    """Run ``query`` on ``cursor``, forwarding the ``async`` flag.

    :param cursor: Object whose ``execute`` method accepts an ``async``
        keyword argument
    :param query: Statement passed to ``cursor.execute`` unchanged
    :param async: Value forwarded as the cursor's ``async`` keyword
    """
    # NOTE(review): ``async`` is a reserved keyword from Python 3.7 onwards;
    # both this signature and the keyword passed to the cursor will need
    # renaming there -- confirm against the driver's API.
    cursor.execute(query, async=async)


class MssqlEngineSpec(BaseEngineSpec):
engine = 'mssql'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""remove double percents
Revision ID: 4451805bbaa1
Revises: bddc498dd179
Create Date: 2018-06-13 10:20:35.846744
"""

# revision identifiers, used by Alembic.
revision = '4451805bbaa1'
down_revision = 'bddc498dd179'


import json
import logging

from alembic import op
from sqlalchemy import Column, create_engine, ForeignKey, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base

from superset import db

Base = declarative_base()


class Slice(Base):
    """Minimal mapping of the ``slices`` table for use in this migration."""

    __tablename__ = 'slices'

    id = Column(Integer, primary_key=True)
    # Reference to the owning row in ``tables``
    datasource_id = Column(Integer, ForeignKey('tables.id'))
    datasource_type = Column(String(200))
    params = Column(Text)


class Table(Base):
    """Minimal mapping of the ``tables`` table for use in this migration."""

    __tablename__ = 'tables'

    id = Column(Integer, primary_key=True)
    # Reference to the owning row in ``dbs``
    database_id = Column(Integer, ForeignKey('dbs.id'))


class Database(Base):
    """Minimal mapping of the ``dbs`` table for use in this migration."""

    __tablename__ = 'dbs'

    id = Column(Integer, primary_key=True)
    sqlalchemy_uri = Column(String(1024))


def replace(source, target):
    """Substitute ``source`` with ``target`` in saved adhoc filter SQL.

    Visits every slice whose ``datasource_type`` is ``'table'`` and, when the
    dialect of the slice's database has ``_double_percents`` set on its
    identifier preparer, rewrites the ``sqlExpression`` of each entry under
    the ``adhoc_filters`` key of the slice's JSON ``params``.

    The rewrite is best effort: a database whose URI cannot be turned into an
    engine, or a slice whose ``params`` cannot be processed, is logged and
    left untouched so that one bad record does not abort the migration.

    :param source: Substring to look for in each ``sqlExpression``
    :param target: Replacement for every occurrence of ``source``
    """
    bind = op.get_bind()
    session = db.Session(bind=bind)

    try:
        query = (
            session.query(Slice, Database)
            .join(Table)
            .join(Database)
            .filter(Slice.datasource_type == 'table')
            .all()
        )

        # Several slices usually share one database, so the dialect is
        # inspected once per database id rather than once per slice.
        double_percents = {}

        for slc, database in query:
            if database.id not in double_percents:
                try:
                    engine = create_engine(database.sqlalchemy_uri)
                    double_percents[database.id] = bool(
                        engine.dialect.identifier_preparer._double_percents,
                    )
                except Exception:
                    # Log the id only: the URI may embed credentials.
                    logging.exception(
                        'Unable to inspect the dialect of database %s; '
                        'its slices are left unchanged',
                        database.id,
                    )
                    double_percents[database.id] = False

            if not double_percents[database.id]:
                continue

            try:
                params = json.loads(slc.params)

                if 'adhoc_filters' in params:
                    for filt in params['adhoc_filters']:
                        if 'sqlExpression' in filt:
                            filt['sqlExpression'] = (
                                filt['sqlExpression'].replace(source, target)
                            )

                    # Only reached when every filter was rewritten, so a
                    # slice is never persisted half-converted.
                    slc.params = json.dumps(params, sort_keys=True)
            except Exception:
                logging.exception(
                    'Unable to rewrite the adhoc filters of slice %s',
                    slc.id,
                )

        session.commit()
    finally:
        # Release the session even when the query or the commit fails.
        session.close()


def upgrade():
    """Collapse each ``%%`` into ``%`` via :func:`replace`."""
    replace(source='%%', target='%')


def downgrade():
    """Expand each ``%`` into ``%%`` via :func:`replace`."""
    replace(source='%', target='%%')
43 changes: 30 additions & 13 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import print_function
from __future__ import unicode_literals

from contextlib import closing
from copy import copy, deepcopy
from datetime import datetime
import functools
Expand All @@ -19,6 +20,7 @@
from future.standard_library import install_aliases
import numpy
import pandas as pd
import six
import sqlalchemy as sqla
from sqlalchemy import (
Boolean, Column, create_engine, DateTime, ForeignKey, Integer,
Expand Down Expand Up @@ -749,12 +751,7 @@ def get_quoter(self):

def get_df(self, sql, schema):
sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)]
eng = self.get_sqla_engine(schema=schema)

for i in range(len(sqls) - 1):
eng.execute(sqls[i])

df = pd.read_sql_query(sqls[-1], eng)
engine = self.get_sqla_engine(schema=schema)

def needs_conversion(df_series):
if df_series.empty:
Expand All @@ -763,15 +760,35 @@ def needs_conversion(df_series):
return True
return False

for k, v in df.dtypes.items():
if v.type == numpy.object_ and needs_conversion(df[k]):
df[k] = df[k].apply(utils.json_dumps_w_dates)
return df
with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor:
for sql in sqls:
self.db_engine_spec.execute(cursor, sql)
df = pd.DataFrame.from_records(
data=list(cursor.fetchall()),
columns=[col_desc[0] for col_desc in cursor.description],
coerce_float=True,
)

for k, v in df.dtypes.items():
if v.type == numpy.object_ and needs_conversion(df[k]):
df[k] = df[k].apply(utils.json_dumps_w_dates)
return df

def compile_sqla_query(self, qry, schema=None):
    """Render ``qry`` as SQL text with its bind parameters inlined.

    :param qry: Object exposing ``compile(engine, compile_kwargs=...)``
    :param schema: Optional schema forwarded to ``get_sqla_engine``
    :returns: The compiled statement as a text string
    """
    engine = self.get_sqla_engine(schema=schema)

    sql = six.text_type(
        qry.compile(
            engine,
            compile_kwargs={'literal_binds': True},
        ),
    )

    # Dialects flagged with ``_double_percents`` render a literal percent
    # sign as ``%%``; collapse it back so the returned text carries single
    # percent signs (presumably because it is run without parameter
    # substitution -- confirm against the callers).
    if engine.dialect.identifier_preparer._double_percents:
        sql = sql.replace('%%', '%')

    return sql

def select_star(
self, table_name, schema=None, limit=100, show_cols=False,
Expand Down
3 changes: 1 addition & 2 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ def handle_error(msg):
cursor = conn.cursor()
logging.info('Running query: \n{}'.format(executed_sql))
logging.info(query.executed_sql)
cursor.execute(query.executed_sql,
**db_engine_spec.cursor_execute_kwargs)
db_engine_spec.execute(cursor, query.executed_sql, async=True)
logging.info('Handling cursor')
db_engine_spec.handle_cursor(cursor, query, session)
logging.info('Fetching data: {}'.format(query.to_dict()))
Expand Down
9 changes: 9 additions & 0 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,15 @@ def test_csv_endpoint(self):
expected_data = csv.reader(
io.StringIO('first_name,last_name\nadmin, user\n'))

sql = "SELECT first_name FROM ab_user WHERE first_name LIKE '%admin%'"
client_id = '{}'.format(random.getrandbits(64))[:10]
self.run_sql(sql, client_id, raise_on_error=True)

resp = self.get_resp('/superset/csv/{}'.format(client_id))
data = csv.reader(io.StringIO(resp))
expected_data = csv.reader(
io.StringIO('first_name\nadmin\n'))

self.assertEqual(list(expected_data), list(data))
self.logout()

Expand Down
2 changes: 1 addition & 1 deletion tests/sqllab_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_sqllab_viz(self):
'sql': """\
SELECT viz_type, count(1) as ccount
FROM slices
WHERE viz_type LIKE '%%a%%'
WHERE viz_type LIKE '%a%'
GROUP BY viz_type""",
'dbId': 1,
}
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ setenv =
SUPERSET_CONFIG = tests.superset_test_config
SUPERSET_HOME = {envtmpdir}
py27-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset?charset=utf8
py34-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset
py{34,36}-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset
py{27,34,36}-postgres: SUPERSET__SQLALCHEMY_DATABASE_URI = postgresql+psycopg2://postgresuser:pguserpassword@localhost/superset
py{27,34,36}-sqlite: SUPERSET__SQLALCHEMY_DATABASE_URI = sqlite:////{envtmpdir}/superset.db
whitelist_externals =
Expand Down

0 comments on commit 7fcc2af

Please sign in to comment.