Skip to content

Commit

Permalink
Add hive to superset + monkey patch the pyhive (apache#2134)
Browse files Browse the repository at this point in the history
* Initial hive implementation

* Fix select star query for hive.

* Exclude generated code.

* Address code coverage and linting.

* Exclude generated code from coveralls.

* Fix lint errors

* Move TCLIService to it's own repo.

* Address comments

* Implement special postgres case,
  • Loading branch information
bkyryliuk authored Mar 7, 2017
1 parent ad4a950 commit 9114d86
Show file tree
Hide file tree
Showing 9 changed files with 504 additions and 218 deletions.
334 changes: 305 additions & 29 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@
from __future__ import unicode_literals

from collections import namedtuple, defaultdict
from flask_babel import lazy_gettext as _
from superset import utils

import inspect
import re
import sqlparse
import textwrap
import time

from superset import cache_util
from sqlalchemy import select
from sqlalchemy.sql import text
from superset.utils import SupersetTemplateException
from flask_babel import lazy_gettext as _

Grain = namedtuple('Grain', 'name label function')

Expand All @@ -37,9 +42,16 @@ class LimitMethod(object):

class BaseEngineSpec(object):
engine = 'base' # str as defined in sqlalchemy.engine.engine
cursor_execute_kwargs = {}
time_grains = tuple()
limit_method = LimitMethod.FETCH_MANY

@classmethod
def fetch_data(cls, cursor, limit):
if cls.limit_method == LimitMethod.FETCH_MANY:
return cursor.fetchmany(limit)
return cursor.fetchall()

@classmethod
def epoch_to_dttm(cls):
raise NotImplementedError()
Expand Down Expand Up @@ -106,6 +118,40 @@ def sql_preprocessor(cls, sql):
"""
return sql

@classmethod
def patch(cls):
pass

@classmethod
def where_latest_partition(
cls, table_name, schema, database, qry, columns=None):
return False

@classmethod
def select_star(cls, my_db, table_name, schema=None, limit=100,
show_cols=False, indent=True):
fields = '*'
table = my_db.get_table(table_name, schema=schema)
if show_cols:
fields = [my_db.get_quoter()(c.name) for c in table.columns]
full_table_name = table_name
if schema:
full_table_name = schema + '.' + table_name
qry = select(fields)
if limit:
qry = qry.limit(limit)
partition_query = cls.where_latest_partition(
table_name, schema, my_db, qry, columns=table.columns)
# if not partition_query condition fails.
if partition_query == False: # noqa
qry = qry.select_from(text(full_table_name))
else:
qry = partition_query
sql = my_db.compile_sqla_query(qry)
if indent:
sql = sqlparse.format(sql, reindent=True)
return sql


class PostgresEngineSpec(BaseEngineSpec):
engine = 'postgresql'
Expand All @@ -122,6 +168,14 @@ class PostgresEngineSpec(BaseEngineSpec):
Grain("year", _('year'), "DATE_TRUNC('year', {col})"),
)

@classmethod
def fetch_data(cls, cursor, limit):
if not cursor.description:
return []
if cls.limit_method == LimitMethod.FETCH_MANY:
return cursor.fetchmany(limit)
return cursor.fetchall()

@classmethod
def epoch_to_dttm(cls):
return "(timestamp 'epoch' + {col} * interval '1 second')"
Expand Down Expand Up @@ -235,27 +289,6 @@ def convert_dttm(cls, target_type, dttm):
def epoch_to_dttm(cls):
return "from_unixtime({col})"

@staticmethod
def show_partition_pql(
table_name, schema_name=None, order_by=None, limit=100):
if schema_name:
table_name = schema_name + '.' + table_name
order_by = order_by or []
order_by_clause = ''
if order_by:
order_by_clause = "ORDER BY " + ', '.join(order_by) + " DESC"

limit_clause = ''
if limit:
limit_clause = "LIMIT {}".format(limit)

return textwrap.dedent("""\
SHOW PARTITIONS
FROM {table_name}
{order_by_clause}
{limit_clause}
""").format(**locals())

@classmethod
@cache_util.memoized_func(
timeout=600,
Expand Down Expand Up @@ -284,16 +317,14 @@ def extra_table_metadata(cls, database, table_name, schema_name):
if not indexes:
return {}
cols = indexes[0].get('column_names', [])
pql = cls.show_partition_pql(table_name, schema_name, cols)
df = database.get_df(pql, schema_name)
latest_part = df.to_dict(orient='records')[0] if not df.empty else None

partition_query = cls.show_partition_pql(table_name, schema_name, cols)
pql = cls._partition_query(table_name, schema_name, cols)
col_name, latest_part = cls.latest_partition(
table_name, schema_name, database)
return {
'partitions': {
'cols': cols,
'latest': latest_part,
'partitionQuery': partition_query,
'latest': {col_name: latest_part},
'partitionQuery': pql,
}
}

Expand Down Expand Up @@ -332,6 +363,251 @@ def extract_error_message(cls, e):
)
return utils.error_msg_from_exception(e)

@classmethod
def _partition_query(
cls, table_name, limit=0, order_by=None, filters=None):
"""Returns a partition query
:param table_name: the name of the table to get partitions from
:type table_name: str
:param limit: the number of partitions to be returned
:type limit: int
:param order_by: a list of tuples of field name and a boolean
that determines if that field should be sorted in descending
order
:type order_by: list of (str, bool) tuples
:param filters: a list of filters to apply
:param filters: dict of field name and filter value combinations
"""
limit_clause = "LIMIT {}".format(limit) if limit else ''
order_by_clause = ''
if order_by:
l = []
for field, desc in order_by:
l.append(field + ' DESC' if desc else '')
order_by_clause = 'ORDER BY ' + ', '.join(l)

where_clause = ''
if filters:
l = []
for field, value in filters.items():
l.append("{field} = '{value}'".format(**locals()))
where_clause = 'WHERE ' + ' AND '.join(l)

sql = textwrap.dedent("""\
SHOW PARTITIONS FROM {table_name}
{where_clause}
{order_by_clause}
{limit_clause}
""").format(**locals())
return sql

@classmethod
def _latest_partition_from_df(cls, df):
return df.to_records(index=False)[0][0]

@classmethod
def latest_partition(cls, table_name, schema, database):
"""Returns col name and the latest (max) partition value for a table
:param table_name: the name of the table
:type table_name: str
:param schema: schema / database / namespace
:type schema: str
:param database: database query will be run against
:type database: models.Database
>>> latest_partition('foo_table')
'2018-01-01'
"""
indexes = database.get_indexes(table_name, schema)
if len(indexes[0]['column_names']) < 1:
raise SupersetTemplateException(
"The table should have one partitioned field")
elif len(indexes[0]['column_names']) > 1:
raise SupersetTemplateException(
"The table should have a single partitioned field "
"to use this function. You may want to use "
"`presto.latest_sub_partition`")
part_field = indexes[0]['column_names'][0]
sql = cls._partition_query(table_name, 1, [(part_field, True)])
df = database.get_df(sql, schema)
return part_field, cls._latest_partition_from_df(df)

@classmethod
def latest_sub_partition(cls, table_name, schema, database, **kwargs):
"""Returns the latest (max) partition value for a table
A filtering criteria should be passed for all fields that are
partitioned except for the field to be returned. For example,
if a table is partitioned by (``ds``, ``event_type`` and
``event_category``) and you want the latest ``ds``, you'll want
to provide a filter as keyword arguments for both
``event_type`` and ``event_category`` as in
``latest_sub_partition('my_table',
event_category='page', event_type='click')``
:param table_name: the name of the table, can be just the table
name or a fully qualified table name as ``schema_name.table_name``
:type table_name: str
:param schema: schema / database / namespace
:type schema: str
:param database: database query will be run against
:type database: models.Database
:param kwargs: keyword arguments define the filtering criteria
on the partition list. There can be many of these.
:type kwargs: str
>>> latest_sub_partition('sub_partition_table', event_type='click')
'2018-01-01'
"""
indexes = database.get_indexes(table_name, schema)
part_fields = indexes[0]['column_names']
for k in kwargs.keys():
if k not in k in part_fields:
msg = "Field [{k}] is not part of the portioning key"
raise SupersetTemplateException(msg)
if len(kwargs.keys()) != len(part_fields) - 1:
msg = (
"A filter needs to be specified for {} out of the "
"{} fields."
).format(len(part_fields)-1, len(part_fields))
raise SupersetTemplateException(msg)

for field in part_fields:
if field not in kwargs.keys():
field_to_return = field

sql = cls._partition_query(
table_name, 1, [(field_to_return, True)], kwargs)
df = database.get_df(sql, schema)
if df.empty:
return ''
return df.to_dict()[field_to_return][0]


class HiveEngineSpec(PrestoEngineSpec):

"""Reuses PrestoEngineSpec functionality."""

engine = 'hive'
cursor_execute_kwargs = {'async': True}

@classmethod
def patch(cls):
from pyhive import hive
from superset.db_engines import hive as patched_hive
from pythrifthiveapi.TCLIService import (
constants as patched_constants,
ttypes as patched_ttypes,
TCLIService as patched_TCLIService)

hive.TCLIService = patched_TCLIService
hive.constants = patched_constants
hive.ttypes = patched_ttypes
hive.Cursor.fetch_logs = patched_hive.fetch_logs

@classmethod
@cache_util.memoized_func(
timeout=600,
key=lambda *args, **kwargs: 'db:{}:{}'.format(args[0].id, args[1]))
def fetch_result_sets(cls, db, datasource_type, force=False):
return BaseEngineSpec.fetch_result_sets(
db, datasource_type, force=force)

@classmethod
def progress(cls, logs):
# 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
jobs_stats_r = re.compile(
r'.*INFO.*Total jobs = (?P<max_jobs>[0-9]+)')
# 17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5
launching_job_r = re.compile(
'.*INFO.*Launching Job (?P<job_number>[0-9]+) out of '
'(?P<max_jobs>[0-9]+)')
# 17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18
# map = 0%, reduce = 0%
stage_progress = re.compile(
r'.*INFO.*Stage-(?P<stage_number>[0-9]+).*'
r'map = (?P<map_progress>[0-9]+)%.*'
r'reduce = (?P<reduce_progress>[0-9]+)%.*')
total_jobs = None
current_job = None
stages = {}
lines = logs.splitlines()
for line in lines:
match = jobs_stats_r.match(line)
if match:
total_jobs = int(match.groupdict()['max_jobs'])
match = launching_job_r.match(line)
if match:
current_job = int(match.groupdict()['job_number'])
stages = {}
match = stage_progress.match(line)
if match:
stage_number = int(match.groupdict()['stage_number'])
map_progress = int(match.groupdict()['map_progress'])
reduce_progress = int(match.groupdict()['reduce_progress'])
stages[stage_number] = (map_progress + reduce_progress) / 2

if not total_jobs or not current_job:
return 0
stage_progress = sum(
stages.values()) / len(stages.values()) if stages else 0

progress = (
100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
)
return int(progress)

@classmethod
def handle_cursor(cls, cursor, query, session):
"""Updates progress information"""
from pyhive import hive
print("PATCHED TCLIService {}".format(hive.TCLIService.__file__))
unfinished_states = (
hive.ttypes.TOperationState.INITIALIZED_STATE,
hive.ttypes.TOperationState.RUNNING_STATE,
)
polled = cursor.poll()
while polled.operationState in unfinished_states:
resp = cursor.fetch_logs()
if resp and resp.log:
progress = cls.progress(resp.log)
if progress > query.progress:
query.progress = progress
session.commit()
time.sleep(5)
polled = cursor.poll()

@classmethod
def where_latest_partition(
cls, table_name, schema, database, qry, columns=None):
try:
col_name, value = cls.latest_partition(
table_name, schema, database)
except Exception:
# table is not partitioned
return False
for c in columns:
if str(c.name) == str(col_name):
return qry.where(c == str(value))
return False

@classmethod
def latest_sub_partition(cls, table_name, **kwargs):
# TODO(bogdan): implement`
pass

@classmethod
def _latest_partition_from_df(cls, df):
"""Hive partitions look like ds={partition name}"""
return df.ix[:, 0].max().split('=')[1]

@classmethod
def _partition_query(
cls, table_name, limit=0, order_by=None, filters=None):
return "SHOW PARTITIONS {table_name}".format(**locals())


class MssqlEngineSpec(BaseEngineSpec):
engine = 'mssql'
Expand Down
Empty file added superset/db_engines/__init__.py
Empty file.
Loading

0 comments on commit 9114d86

Please sign in to comment.