Skip to content

Commit

Permalink
Feature qubole hook support headers (apache#15683)
Browse files Browse the repository at this point in the history
* Added support for including headers in qubole results

* Added missing issue number

* Fixed pylint errors and warnings

* Fixed formatting that caused errors with static checks

* Apparently Qubole expects the strings 'true'/'false' rather than booleans

* Adding Unit Tests for include_headers flag

* Fixing static-check pylint errors

* Added better unit tests for Qubole's Hook and Operator

* fixed typo in pylint comment

* Fixing failed static checks and lint errors

* fixed some more lint issues
  • Loading branch information
levyitay authored May 11, 2021
1 parent cbc3cb8 commit 996965a
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 6 deletions.
3 changes: 2 additions & 1 deletion airflow/providers/qubole/hooks/qubole.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,9 @@ def get_results(
cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=self.task_id)
self.cmd = self.cls.find(cmd_id)

include_headers_str = 'true' if include_headers else 'false'
self.cmd.get_results(
fp, inline, delim, fetch, arguments=[include_headers]
fp, inline, delim, fetch, arguments=[include_headers_str]
) # type: ignore[attr-defined]
fp.flush()
fp.close()
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/qubole/hooks/qubole_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_query_results(self) -> Optional[str]:
cmd_id = self.cmd.id
self.log.info("command id: %d", cmd_id)
query_result_buffer = StringIO()
self.cmd.get_results(fp=query_result_buffer, inline=True, delim=COL_DELIM, arguments=[True])
self.cmd.get_results(fp=query_result_buffer, inline=True, delim=COL_DELIM, arguments=['true'])
query_result = query_result_buffer.getvalue()
query_result_buffer.close()
return query_result
Expand Down
44 changes: 42 additions & 2 deletions tests/providers/qubole/hooks/test_qubole.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,29 @@
# specific language governing permissions and limitations
# under the License.
#
import unittest
from unittest import TestCase, mock

from qds_sdk.commands import PrestoCommand

from airflow.providers.qubole.hooks.qubole import QuboleHook

DAG_ID = "qubole_test_dag"
TASK_ID = "test_task"
RESULTS_WITH_HEADER = 'header1\theader2\nval1\tval2'
RESULTS_WITH_NO_HEADER = 'val1\tval2'

add_tags = QuboleHook._add_tags


class TestQuboleHook(unittest.TestCase):
# pylint: disable = unused-argument
def get_result_mock(fp, inline, delim, fetch, arguments):
    """Test double patched in for ``qds_sdk.commands.Command.get_results``.

    Writes the header-bearing fixture to ``fp`` when the first entry of
    ``arguments`` is the string ``'true'``; otherwise writes the header-less
    fixture. ``inline``, ``delim`` and ``fetch`` are accepted only to mirror
    the call made by the hook and are ignored.
    """
    wants_headers = arguments[0] == 'true'
    payload = RESULTS_WITH_HEADER if wants_headers else RESULTS_WITH_NO_HEADER
    fp.write(bytearray(payload, 'utf-8'))


class TestQuboleHook(TestCase):
def test_add_string_to_tags(self):
tags = {'dag_id', 'task_id'}
add_tags(tags, 'string')
Expand All @@ -38,3 +53,28 @@ def test_add_tuple_to_tags(self):
tags = {'dag_id', 'task_id'}
add_tags(tags, ('value1', 'value2'))
assert {'dag_id', 'task_id', 'value1', 'value2'} == tags

@mock.patch('qds_sdk.commands.Command.get_results', new=get_result_mock)
def test_get_results_with_headers(self):
    """With ``include_headers=True`` the results file must contain the header row."""
    dag = mock.MagicMock()
    dag.dag_id = DAG_ID
    hook = QuboleHook(task_id=TASK_ID, command_type='prestocmd', dag=dag)

    # The task instance only needs to hand back a command id via xcom_pull.
    task = mock.MagicMock()
    task.xcom_pull.return_value = 'test_command_id'
    with mock.patch('qds_sdk.resource.Resource.find', return_value=PrestoCommand):
        results_path = hook.get_results(ti=task, include_headers=True)

    # Read through a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(results_path) as results_file:
        results = results_file.read()
    assert results == RESULTS_WITH_HEADER

@mock.patch('qds_sdk.commands.Command.get_results', new=get_result_mock)
def test_get_results_without_headers(self):
    """With ``include_headers=False`` the results file must contain only data rows."""
    dag = mock.MagicMock()
    dag.dag_id = DAG_ID
    hook = QuboleHook(task_id=TASK_ID, command_type='prestocmd', dag=dag)

    # The task instance only needs to hand back a command id via xcom_pull.
    task = mock.MagicMock()
    task.xcom_pull.return_value = 'test_command_id'

    with mock.patch('qds_sdk.resource.Resource.find', return_value=PrestoCommand):
        results_path = hook.get_results(ti=task, include_headers=False)

    # Read through a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(results_path) as results_file:
        results = results_file.read()
    assert results == RESULTS_WITH_NO_HEADER
18 changes: 16 additions & 2 deletions tests/providers/qubole/operators/test_qubole.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
#

import unittest
from unittest import TestCase, mock

from airflow import settings
from airflow.models import DAG, Connection
Expand All @@ -36,7 +36,7 @@
DEFAULT_DATE = datetime(2017, 1, 1)


class TestQuboleOperator(unittest.TestCase):
class TestQuboleOperator(TestCase):
def setUp(self):
db.merge_conn(Connection(conn_id=DEFAULT_CONN, conn_type='HTTP'))
db.merge_conn(Connection(conn_id=TEST_CONN, conn_type='HTTP', host='http://localhost/api'))
Expand Down Expand Up @@ -180,3 +180,17 @@ def test_parameter_pool_passed(self):
test_pool = 'test_pool'
op = QuboleOperator(task_id=TASK_ID, pool=test_pool)
assert op.pool == test_pool

@mock.patch('airflow.providers.qubole.hooks.qubole.QuboleHook.get_results')
def test_parameter_include_header_passed(self, mock_get_results):
    """``include_headers=True`` given to the operator must reach the hook."""
    dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
    qubole_operator = QuboleOperator(task_id=TASK_ID, dag=dag, command_type='prestocmd')
    qubole_operator.get_results(include_headers=True)

    # The previous ``asset_called_with`` was a typo: on a MagicMock it merely
    # creates a child mock and checks nothing, so the test could never fail.
    mock_get_results.assert_called_once()
    args, kwargs = mock_get_results.call_args
    # NOTE(review): how the operator forwards the flag is not visible here;
    # accept either a keyword argument or the trailing positional argument.
    # Confirm against QuboleOperator.get_results and tighten if possible.
    forwarded = kwargs['include_headers'] if 'include_headers' in kwargs else args[-1]
    assert forwarded is True

@mock.patch('airflow.providers.qubole.hooks.qubole.QuboleHook.get_results')
def test_parameter_include_header_missing(self, mock_get_results):
    """Omitting ``include_headers`` on the operator must forward ``False`` to the hook."""
    dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
    qubole_operator = QuboleOperator(task_id=TASK_ID, dag=dag, command_type='prestocmd')
    qubole_operator.get_results()

    # The previous ``asset_called_with`` was a typo: on a MagicMock it merely
    # creates a child mock and checks nothing, so the test could never fail.
    mock_get_results.assert_called_once()
    args, kwargs = mock_get_results.call_args
    # NOTE(review): how the operator forwards the flag is not visible here;
    # accept either a keyword argument or the trailing positional argument.
    # Confirm against QuboleOperator.get_results and tighten if possible.
    forwarded = kwargs['include_headers'] if 'include_headers' in kwargs else args[-1]
    assert forwarded is False

0 comments on commit 996965a

Please sign in to comment.