[FEAT] databricks repair run with reason match and appropriate new settings (apache#41412)

gaurav7261 authored Aug 30, 2024
1 parent 6e01118 commit 365b42f
Showing 4 changed files with 244 additions and 4 deletions.
10 changes: 10 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
@@ -44,6 +44,7 @@

CREATE_ENDPOINT = ("POST", "api/2.1/jobs/create")
RESET_ENDPOINT = ("POST", "api/2.1/jobs/reset")
UPDATE_ENDPOINT = ("POST", "api/2.1/jobs/update")
RUN_NOW_ENDPOINT = ("POST", "api/2.1/jobs/run-now")
SUBMIT_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/submit")
GET_RUN_ENDPOINT = ("GET", "api/2.1/jobs/runs/get")
@@ -233,6 +234,15 @@ def reset_job(self, job_id: str, json: dict) -> None:
"""
self._do_api_call(RESET_ENDPOINT, {"job_id": job_id, "new_settings": json})

def update_job(self, job_id: str, json: dict) -> None:
"""
Call the ``api/2.1/jobs/update`` endpoint.

:param job_id: The ID of the job to update.
:param json: The data used as ``new_settings`` in the request to the ``update`` endpoint.
"""
self._do_api_call(UPDATE_ENDPOINT, {"job_id": job_id, "new_settings": json})
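
As a side note for readers, a minimal sketch of calling this new hook method (the connection id and settings payload here are illustrative, not part of this commit):

    from airflow.providers.databricks.hooks.databricks import DatabricksHook

    # Assumed connection id. update_job performs a partial update: only the
    # fields present in new_settings are changed on the job.
    hook = DatabricksHook(databricks_conn_id="databricks_default")
    hook.update_job(job_id="123", json={"max_concurrent_runs": 2})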

def run_now(self, json: dict) -> int:
"""
Call the ``api/2.1/jobs/run-now`` endpoint.
67 changes: 64 additions & 3 deletions airflow/providers/databricks/operators/databricks.py
@@ -100,13 +100,23 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
f"and with the error {run_state.state_message}"
)

if isinstance(operator, DatabricksRunNowOperator) and operator.repair_run:
should_repair = (
isinstance(operator, DatabricksRunNowOperator)
and operator.repair_run
and (
not operator.databricks_repair_reason_new_settings
or is_repair_reason_match_exist(operator, run_state)
)
)

if should_repair:
operator.repair_run = False
log.warning(
"%s but since repair run is set, repairing the run with all failed tasks",
error_message,
)

job_id = operator.json["job_id"]
update_job_for_repair(operator, hook, job_id, run_state)
latest_repair_id = hook.get_latest_repair_id(operator.run_id)
repair_json = {"run_id": operator.run_id, "rerun_all_failed_tasks": True}
if latest_repair_id is not None:
@@ -123,6 +133,41 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
log.info("View run status, Spark UI, and logs at %s", run_page_url)


def is_repair_reason_match_exist(operator: Any, run_state: RunState) -> bool:
"""
Check if the repair reason matches the run state message.

:param operator: Databricks operator being handled
:param run_state: Run state of the Databricks job
:return: True if repair reason matches the run state message, False otherwise
"""
return any(reason in run_state.state_message for reason in operator.databricks_repair_reason_new_settings)


def update_job_for_repair(operator: Any, hook: Any, job_id: int, run_state: RunState) -> None:
"""
Update the job settings (partial update) before repairing the run with all failed tasks.

:param operator: Databricks operator being handled
:param hook: Databricks hook
:param job_id: ID of the Databricks job
:param run_state: Run state of the Databricks job
"""
repair_reason = next(
(
reason
for reason in operator.databricks_repair_reason_new_settings
if reason in run_state.state_message
),
None,
)
if repair_reason is not None:
new_settings_json = normalise_json_content(
operator.databricks_repair_reason_new_settings[repair_reason]
)
hook.update_job(job_id=job_id, json=new_settings_json)
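
As a quick illustration of the substring-matching rule these helpers implement (values are illustrative):

    # A reason key matches when it appears anywhere in the run's state_message,
    # mirroring is_repair_reason_match_exist above.
    state_message = "Cluster was terminated. Reason: AWS_INSUFFICIENT_INSTANCE_CAPACITY_FAILURE (CLIENT_ERROR)."
    new_settings = {"AWS_INSUFFICIENT_INSTANCE_CAPACITY_FAILURE": {"job_clusters": []}}
    any(reason in state_message for reason in new_settings)  # True: update the job, then repair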


def _handle_deferrable_databricks_operator_execution(operator, hook, log, context) -> None:
"""
Handle the Airflow + Databricks lifecycle logic for deferrable Databricks operators.
@@ -674,6 +719,7 @@ class DatabricksRunNowOperator(BaseOperator):
- ``spark_submit_params``
- ``idempotency_token``
- ``repair_run``
- ``databricks_repair_reason_new_settings``
- ``cancel_previous_runs``
:param job_id: the job_id of the existing Databricks job.
@@ -764,6 +810,12 @@ class DatabricksRunNowOperator(BaseOperator):
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param deferrable: Run operator in the deferrable mode.
:param repair_run: Repair the Databricks run in case of failure.
:param databricks_repair_reason_new_settings: A dict mapping a repair reason to the ``new_settings``
JSON object used for the repair. ``None`` by default; ``None`` means repair in all cases with the
existing job settings. Otherwise, the run is repaired only if the ``RunState`` ``state_message``
contains one of the reasons; the job settings are first updated with the matching ``new_settings``
via the Databricks partial job update endpoint (https://docs.databricks.com/api/workspace/jobs/update).
If no reason matches, the repair is not triggered.
:param cancel_previous_runs: Cancel all existing running jobs before submitting new one.
"""

@@ -796,6 +848,7 @@ def __init__(
wait_for_termination: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
repair_run: bool = False,
databricks_repair_reason_new_settings: dict[str, Any] | None = None,
cancel_previous_runs: bool = False,
**kwargs,
) -> None:
@@ -810,6 +863,7 @@
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
self.repair_run = repair_run
self.databricks_repair_reason_new_settings = databricks_repair_reason_new_settings or {}
self.cancel_previous_runs = cancel_previous_runs

if job_id is not None:
@@ -870,9 +924,16 @@ def execute(self, context: Context):
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event:
_handle_deferrable_databricks_operator_completion(event, self.log)
if event["repair_run"]:
run_state = RunState.from_json(event["run_state"])
should_repair = event["repair_run"] and (
not self.databricks_repair_reason_new_settings
or is_repair_reason_match_exist(self, run_state)
)
if should_repair:
self.repair_run = False
self.run_id = event["run_id"]
job_id = self._hook.get_job_id(self.run_id)
update_job_for_repair(self, self._hook, job_id, run_state)
latest_repair_id = self._hook.get_latest_repair_id(self.run_id)
repair_json = {"run_id": self.run_id, "rerun_all_failed_tasks": True}
if latest_repair_id is not None:
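For reference, a sketch of the trigger event shape that execute_complete consumes on this path (field values are illustrative and mirror the tests below; run_state is assumed to be the JSON string produced by RunState.to_json()):

    event = {
        "run_id": 1,
        "run_page_url": "https://xx.cloud.databricks.com/#job/1/run/1",
        "run_state": '{"life_cycle_state": "TERMINATED", "result_state": "FAILED", '
        '"state_message": "Reason: AWS_INSUFFICIENT_INSTANCE_CAPACITY_FAILURE"}',
        "repair_run": True,
        "errors": [],
    }
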
24 changes: 24 additions & 0 deletions tests/providers/databricks/hooks/test_databricks.py
@@ -139,6 +139,13 @@ def reset_endpoint(host):
return f"https://{host}/api/2.1/jobs/reset"


def update_endpoint(host):
"""
Utility function to generate the update endpoint given the host.
"""
return f"https://{host}/api/2.1/jobs/update"


def run_now_endpoint(host):
"""
Utility function to generate the run now endpoint given the host.
@@ -464,6 +471,23 @@ def test_reset(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_update(self, mock_requests):
mock_requests.codes.ok = 200
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock
json = {"name": "test"}
self.hook.update_job(JOB_ID, json)

mock_requests.post.assert_called_once_with(
update_endpoint(HOST),
json={"job_id": JOB_ID, "new_settings": {"name": "test"}},
params=None,
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_submit_run(self, mock_requests):
mock_requests.post.return_value.json.return_value = {"run_id": "1"}
147 changes: 146 additions & 1 deletion tests/providers/databricks/operators/test_databricks.py
@@ -18,6 +18,7 @@
from __future__ import annotations

from datetime import datetime, timedelta
from typing import Any
from unittest import mock
from unittest.mock import MagicMock

@@ -152,7 +153,7 @@
"retry_on_timeout": False,
},
]
JOB_CLUSTERS = [
JOB_CLUSTERS: list[dict[str, Any]] = [
{
"job_cluster_key": "auto_scaling_cluster",
"new_cluster": {
@@ -172,6 +173,45 @@
},
},
]

JOB_CLUSTERS_REPAIR_AWS_INSUFFICIENT_INSTANCE_CAPACITY_FAILURE: list[dict[str, Any]] = [
{
**cluster,
"new_cluster": {
**cluster["new_cluster"],
"aws_attributes": {
**cluster["new_cluster"]["aws_attributes"],
"zone_id": "us-east-2a",
},
},
}
for cluster in JOB_CLUSTERS
]

JOB_CLUSTERS_REPAIR_AWS_MAX_SPOT_INSTANCE_COUNT_EXCEEDED_FAILURE: list[dict[str, Any]] = [
{
**cluster,
"new_cluster": {
**cluster["new_cluster"],
"aws_attributes": {
**cluster["new_cluster"]["aws_attributes"],
"availability": "ON_DEMAND",
},
},
}
for cluster in JOB_CLUSTERS
]


DATABRICKS_REPAIR_REASON_NEW_SETTINGS = {
"AWS_INSUFFICIENT_INSTANCE_CAPACITY_FAILURE": {
"job_clusters": JOB_CLUSTERS_REPAIR_AWS_INSUFFICIENT_INSTANCE_CAPACITY_FAILURE,
},
"AWS_MAX_SPOT_INSTANCE_COUNT_EXCEEDED_FAILURE": {
"job_clusters": JOB_CLUSTERS_REPAIR_AWS_MAX_SPOT_INSTANCE_COUNT_EXCEEDED_FAILURE,
},
}

EMAIL_NOTIFICATIONS = {
"on_start": [
"[email protected]",
@@ -1723,6 +1763,111 @@ def test_execute_complete_failure_and_repair_run(
db_mock.repair_run.assert_called_once()
mock_handle_deferrable_databricks_operator_execution.assert_called_once()

@mock.patch(
"airflow.providers.databricks.operators.databricks._handle_deferrable_databricks_operator_execution"
)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_deferrable_exec_with_databricks_repair_reason_new_settings(
self, db_mock_class, mock_handle_deferrable_databricks_operator_execution
):
"""
Test the deferrable execute function in the case where the user wants to repair with new settings
"""
state_message = f"""Task {TASK_ID} failed with message: Cluster {EXISTING_CLUSTER_ID} was terminated.
Reason: AWS_INSUFFICIENT_INSTANCE_CAPACITY_FAILURE (CLIENT_ERROR).
Parameters: aws_api_error_code:InsufficientInstanceCapacity, aws_error_message:There is no Spot
capacity available that matches your request..
"""
run_state_failed = RunState(
"TERMINATED",
"FAILED",
state_message,
)
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
event = {
"run_id": RUN_ID,
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": True,
"errors": [],
}

op = DatabricksRunNowOperator(
deferrable=True,
task_id=TASK_ID,
job_id=JOB_ID,
json=run,
databricks_repair_reason_new_settings=DATABRICKS_REPAIR_REASON_NEW_SETTINGS,
)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = RUN_ID
db_mock.get_job_id.return_value = JOB_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED", state_message)

op.execute_complete(context=None, event=event)

db_mock.update_job.assert_called_once()
db_mock.update_job.assert_called_with(
job_id=JOB_ID,
json=utils.normalise_json_content(
DATABRICKS_REPAIR_REASON_NEW_SETTINGS["AWS_INSUFFICIENT_INSTANCE_CAPACITY_FAILURE"]
),
)
db_mock.repair_run.assert_called_once()
mock_handle_deferrable_databricks_operator_execution.assert_called_once()

@mock.patch("airflow.providers.databricks.operators.databricks.is_repair_reason_match_exist")
@mock.patch(
"airflow.providers.databricks.operators.databricks._handle_deferrable_databricks_operator_execution"
)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_deferrable_exec_with_none_databricks_repair_reason_new_settings(
self,
db_mock_class,
mock_handle_deferrable_databricks_operator_execution,
mock_handle_is_repair_reason_match_exist,
):
"""
Test the deferrable execute function in the case where the user does not want to repair with new settings
"""
state_message = f"""Task {TASK_ID} failed with message: Cluster {EXISTING_CLUSTER_ID} was terminated.
Reason: AWS_INSUFFICIENT_INSTANCE_CAPACITY_FAILURE (CLIENT_ERROR).
Parameters: aws_api_error_code:InsufficientInstanceCapacity, aws_error_message:There is no Spot
capacity available that matches your request..
"""
run_state_failed = RunState(
"TERMINATED",
"FAILED",
state_message,
)
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
event = {
"run_id": RUN_ID,
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": True,
"errors": [],
}

op = DatabricksRunNowOperator(
deferrable=True,
task_id=TASK_ID,
job_id=JOB_ID,
json=run,
databricks_repair_reason_new_settings=None,
)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = RUN_ID
db_mock.get_job_id.return_value = JOB_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED", state_message)

op.execute_complete(context=None, event=event)

db_mock.update_job.assert_not_called()
db_mock.repair_run.assert_called_once()
mock_handle_deferrable_databricks_operator_execution.assert_called_once()
mock_handle_is_repair_reason_match_exist.assert_not_called()

def test_execute_complete_incorrect_event_validation_failure(self):
event = {"event_id": "no such column"}
op = DatabricksRunNowOperator(deferrable=True, task_id=TASK_ID, job_id=JOB_ID)
