Skip to content

Commit

Permalink
Add Spark's appId to xcom output (apache#27376)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdsoha authored Oct 31, 2022
1 parent 0c94eff commit f75582a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
2 changes: 2 additions & 0 deletions airflow/providers/apache/livy/operators/livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def execute(self, context: Context) -> Any:
if self._polling_interval > 0:
self.poll_for_termination(self._batch_id)

context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"])

return self._batch_id

def poll_for_termination(self, batch_id: int | str) -> None:
Expand Down
29 changes: 17 additions & 12 deletions tests/providers/apache/livy/operators/test_livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
from airflow.utils import db, timezone

DEFAULT_DATE = timezone.datetime(2017, 1, 1)
mock_livy_client = MagicMock()

BATCH_ID = 100
APP_ID = "application_1433865536131_34483"
GET_BATCH = {"appId": APP_ID}
LOG_RESPONSE = {"total": 3, "log": ["first_line", "second_line", "third_line"]}


Expand All @@ -45,14 +45,14 @@ def setUp(self):
conn_id="livyunittest", conn_type="livy", host="localhost:8998", port="8998", schema="http"
)
)
self.mock_context = dict(ti=MagicMock())

@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
return_value=None,
)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state")
def test_poll_for_termination(self, mock_livy, mock_dump_logs):

state_list = 2 * [BatchState.RUNNING] + [BatchState.SUCCESS]

def side_effect(_, retry_args):
Expand All @@ -77,7 +77,6 @@ def side_effect(_, retry_args):
)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state")
def test_poll_for_termination_fail(self, mock_livy, mock_dump_logs):

state_list = 2 * [BatchState.RUNNING] + [BatchState.ERROR]

def side_effect(_, retry_args):
Expand Down Expand Up @@ -107,39 +106,44 @@ def side_effect(_, retry_args):
return_value=BatchState.SUCCESS,
)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
def test_execution(self, mock_post, mock_get, mock_dump_logs):
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
def test_execution(self, mock_get_batch, mock_post, mock_get, mock_dump_logs):
task = LivyOperator(
livy_conn_id="livyunittest",
file="sparkapp",
polling_interval=1,
dag=self.dag,
task_id="livy_example",
)
task.execute(context={})
task.execute(context=self.mock_context)

call_args = {k: v for k, v in mock_post.call_args[1].items() if v}
assert call_args == {"file": "sparkapp"}
mock_get.assert_called_once_with(BATCH_ID, retry_args=None)
mock_dump_logs.assert_called_once_with(BATCH_ID)
mock_get_batch.assert_called_once_with(BATCH_ID)
self.mock_context["ti"].xcom_push.assert_called_once_with(key="app_id", value=APP_ID)

@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch")
def test_execution_with_extra_options(self, mock_post):
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
def test_execution_with_extra_options(self, mock_get_batch, mock_post):
extra_options = {"check_response": True}
task = LivyOperator(
file="sparkapp", dag=self.dag, task_id="livy_example", extra_options=extra_options
)

task.execute(context={})
task.execute(context=self.mock_context)

assert task.get_hook().extra_options == extra_options

@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
def test_deletion(self, mock_post, mock_delete):
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
def test_deletion(self, mock_get_batch, mock_post, mock_delete):
task = LivyOperator(
livy_conn_id="livyunittest", file="sparkapp", dag=self.dag, task_id="livy_example"
)
task.execute(context={})
task.execute(context=self.mock_context)
task.kill()

mock_delete.assert_called_once_with(BATCH_ID)
Expand All @@ -158,7 +162,8 @@ def test_injected_hook(self):
)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_logs", return_value=LOG_RESPONSE)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
def test_log_dump(self, mock_post, mock_get_logs, mock_get):
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
def test_log_dump(self, mock_get_batch, mock_post, mock_get_logs, mock_get):
task = LivyOperator(
livy_conn_id="livyunittest",
file="sparkapp",
Expand All @@ -167,7 +172,7 @@ def test_log_dump(self, mock_post, mock_get_logs, mock_get):
polling_interval=1,
)
with self.assertLogs(task.get_hook().log, level=logging.INFO) as cm:
task.execute(context={})
task.execute(context=self.mock_context)
assert "INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:first_line" in cm.output
assert "INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:second_line" in cm.output
assert "INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:third_line" in cm.output
Expand Down

0 comments on commit f75582a

Please sign in to comment.