Skip to content

Commit

Permalink
Adding configurable fetch_all_handler for JdbcOperator (apache#25412)
Browse files Browse the repository at this point in the history
  • Loading branch information
kazanzhy authored Aug 7, 2022
1 parent 3dfa445 commit 1708da9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
9 changes: 7 additions & 2 deletions airflow/providers/jdbc/operators/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, Sequence, Union

from airflow.models import BaseOperator
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
Expand Down Expand Up @@ -57,16 +57,21 @@ def __init__(
jdbc_conn_id: str = 'jdbc_default',
autocommit: bool = False,
parameters: Optional[Union[Iterable, Mapping]] = None,
handler: Callable[[Any], Any] = fetch_all_handler,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.parameters = parameters
self.sql = sql
self.jdbc_conn_id = jdbc_conn_id
self.autocommit = autocommit
self.handler = handler
self.hook = None

def execute(self, context: 'Context'):
self.log.info('Executing: %s', self.sql)
hook = JdbcHook(jdbc_conn_id=self.jdbc_conn_id)
return hook.run(self.sql, self.autocommit, parameters=self.parameters, handler=fetch_all_handler)
if self.do_xcom_push:
return hook.run(self.sql, self.autocommit, parameters=self.parameters, handler=self.handler)
else:
return hook.run(self.sql, self.autocommit, parameters=self.parameters)
16 changes: 14 additions & 2 deletions tests/providers/jdbc/operators/test_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def setUp(self):
self.kwargs = dict(sql='sql', task_id='test_jdbc_operator', dag=None)

@patch('airflow.providers.jdbc.operators.jdbc.JdbcHook')
def test_execute(self, mock_jdbc_hook):
jdbc_operator = JdbcOperator(**self.kwargs)
def test_execute_do_push(self, mock_jdbc_hook):
jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=True)
jdbc_operator.execute(context={})

mock_jdbc_hook.assert_called_once_with(jdbc_conn_id=jdbc_operator.jdbc_conn_id)
Expand All @@ -39,3 +39,15 @@ def test_execute(self, mock_jdbc_hook):
parameters=jdbc_operator.parameters,
handler=fetch_all_handler,
)

@patch('airflow.providers.jdbc.operators.jdbc.JdbcHook')
def test_execute_dont_push(self, mock_jdbc_hook):
jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=False)
jdbc_operator.execute(context={})

mock_jdbc_hook.assert_called_once_with(jdbc_conn_id=jdbc_operator.jdbc_conn_id)
mock_jdbc_hook.return_value.run.assert_called_once_with(
jdbc_operator.sql,
jdbc_operator.autocommit,
parameters=jdbc_operator.parameters,
)

0 comments on commit 1708da9

Please sign in to comment.