Ensure that zombie tasks for dags with errors get cleaned up (apache#25550)

If there is a parse error in a DAG, the zombie cleanup request never runs, which
results in the TI never leaving the running state and just continually being
re-detected as a zombie.

(Prior to AIP-45 landing, this bug/behaviour resulted in a DAG with a
parse error never actually leaving the queued state.)

The fix here is to _always_ run `ti.handle_failure` when we are given a
request, even if we can't load the DAG. To do the best job we can, we try to
load the serialized DAG, but in cases where that isn't possible for whatever
reason we also make sure `TaskInstance.handle_failure` is able to operate even
when `self.task` is None.
ashb authored Aug 5, 2022
1 parent d4f560b commit 1d8507a
Showing 6 changed files with 162 additions and 27 deletions.
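
In outline, the change makes the task-callback path tolerant of a missing DAG: look the TaskInstance up directly, recover the task definition from the serialized DAG if possible, and call handle_failure regardless. A minimal sketch of that idea, assuming a session and callback request are already in hand (the helper name handle_zombie_request is invented for illustration; the real logic lives in DagFileProcessor._execute_task_callbacks below):

from sqlalchemy import exc

from airflow.exceptions import TaskNotFound
from airflow.models import TaskInstance as TI
from airflow.models.serialized_dag import SerializedDagModel


def handle_zombie_request(request, session):
    # Sketch only: always progress the TI, even when the DAG could not be parsed.
    simple_ti = request.simple_task_instance
    ti = (
        session.query(TI)
        .filter_by(
            dag_id=simple_ti.dag_id,
            run_id=simple_ti.run_id,
            task_id=simple_ti.task_id,
            map_index=simple_ti.map_index,
        )
        .one_or_none()
    )
    if not ti:
        return

    task = None
    try:
        # Best effort: fall back to the serialized DAG to recover the task definition.
        model = session.query(SerializedDagModel).get(simple_ti.dag_id)
        if model:
            task = model.dag.get_task(simple_ti.task_id)
    except (exc.NoResultFound, TaskNotFound):
        pass
    if task:
        ti.refresh_from_task(task)

    # handle_failure now copes with ti.task being None (see taskinstance.py below).
    ti.handle_failure(error=request.msg, session=session)
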
93 changes: 78 additions & 15 deletions airflow/dag_processing/processor.py
@@ -25,10 +25,10 @@
from contextlib import redirect_stderr, redirect_stdout, suppress
from datetime import timedelta
from multiprocessing.connection import Connection as MultiprocessingConnection
from typing import Iterator, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Iterator, List, Optional, Set, Tuple

from setproctitle import setproctitle
from sqlalchemy import func, or_
from sqlalchemy import exc, func, or_
from sqlalchemy.orm.session import Session

from airflow import models, settings
@@ -52,6 +52,9 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State

if TYPE_CHECKING:
from airflow.models.operator import Operator

DR = models.DagRun
TI = models.TaskInstance

@@ -625,7 +628,7 @@ def execute_callbacks(
self.log.debug("Processing Callback Request: %s", request)
try:
if isinstance(request, TaskCallbackRequest):
self._execute_task_callbacks(dagbag, request)
self._execute_task_callbacks(dagbag, request, session=session)
elif isinstance(request, SlaCallbackRequest):
self.manage_slas(dagbag.get_dag(request.dag_id), session=session)
elif isinstance(request, DagCallbackRequest):
@@ -637,7 +640,27 @@
request.full_filepath,
)

session.commit()
session.flush()

def execute_callbacks_without_dag(
self, callback_requests: List[CallbackRequest], session: Session
) -> None:
"""
Execute what callbacks we can as "best effort" when the dag cannot be found/had parse errors.
This is important so that tasks that failed when there is a parse
error don't get stuck in queued state.
"""
for request in callback_requests:
self.log.debug("Processing Callback Request: %s", request)
if isinstance(request, TaskCallbackRequest):
self._execute_task_callbacks(None, request, session)
else:
self.log.info(
"Not executing %s callback for file %s as there was a dag parse error",
request.__class__.__name__,
request.full_filepath,
)

@provide_session
def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session):
@@ -647,26 +670,59 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se
dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session
)

def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
def _execute_task_callbacks(
self, dagbag: Optional[DagBag], request: TaskCallbackRequest, session: Session
):
if not request.is_failure_callback:
return

simple_ti = request.simple_task_instance
if simple_ti.dag_id in dagbag.dags:
ti: Optional[TI] = (
session.query(TI)
.filter_by(
dag_id=simple_ti.dag_id,
run_id=simple_ti.run_id,
task_id=simple_ti.task_id,
map_index=simple_ti.map_index,
)
.one_or_none()
)
if not ti:
return

task: Optional["Operator"] = None

if dagbag and simple_ti.dag_id in dagbag.dags:
dag = dagbag.dags[simple_ti.dag_id]
if simple_ti.task_id in dag.task_ids:
task = dag.get_task(simple_ti.task_id)
if request.is_failure_callback:
ti = TI(task, run_id=simple_ti.run_id, map_index=simple_ti.map_index)
# TODO: Use simple_ti to improve performance here in the future
ti.refresh_from_db()
ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE)
self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
else:
# We don't have the _real_ dag here (perhaps it had a parse error?) but we still want to run
# `handle_failure` so that the state of the TI gets progressed.
#
# Since handle_failure _really_ wants a task, we do our best effort to give it one
from airflow.models.serialized_dag import SerializedDagModel

try:
model = session.query(SerializedDagModel).get(simple_ti.dag_id)
if model:
task = model.dag.get_task(simple_ti.task_id)
except (exc.NoResultFound, TaskNotFound):
pass
if task:
ti.refresh_from_task(task)

ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session)
self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
session.flush()

@provide_session
def process_file(
self,
file_path: str,
callback_requests: List[CallbackRequest],
pickle_dags: bool = False,
session: Session = None,
session: Session = NEW_SESSION,
) -> Tuple[int, int]:
"""
Process a Python file containing Airflow DAGs.
@@ -702,12 +758,19 @@ def process_file(
else:
self.log.warning("No viable dags retrieved from %s", file_path)
self.update_import_errors(session, dagbag)
if callback_requests:
# If there were callback requests for this file but there was a
# parse error we still need to progress the state of TIs,
# otherwise they might be stuck in queued/running for ever!
self.execute_callbacks_without_dag(callback_requests, session)
return 0, len(dagbag.import_errors)

self.execute_callbacks(dagbag, callback_requests)
self.execute_callbacks(dagbag, callback_requests, session)
session.commit()

# Save individual DAGs in the ORM
dagbag.sync_to_db()
dagbag.sync_to_db(session)
session.commit()

if pickle_dags:
paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids)
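For orientation, the requests this processor handles are produced when the scheduler detects a zombie task; the new test below builds one by hand with the same fields. A small illustrative sketch (the import path and file path are assumptions, and `ti` is an existing TaskInstance):

from airflow.callbacks.callback_requests import TaskCallbackRequest
from airflow.models.taskinstance import SimpleTaskInstance

# Hypothetical example: wrap an existing TaskInstance into a failure-callback request.
request = TaskCallbackRequest(
    full_filepath="/dags/broken_dag.py",  # assumed path to the (unparseable) DAG file
    simple_task_instance=SimpleTaskInstance.from_ti(ti),
    msg="Detected as zombie",
)
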
3 changes: 2 additions & 1 deletion airflow/models/log.py
@@ -55,7 +55,8 @@ def __init__(self, event, task_instance=None, owner=None, extra=None, **kwargs):
self.task_id = task_instance.task_id
self.execution_date = task_instance.execution_date
self.map_index = task_instance.map_index
task_owner = task_instance.task.owner
if task_instance.task:
task_owner = task_instance.task.owner

if 'task_id' in kwargs:
self.task_id = kwargs['task_id']
17 changes: 11 additions & 6 deletions airflow/models/taskinstance.py
@@ -1845,9 +1845,6 @@ def handle_failure(
if test_mode is None:
test_mode = self.test_mode

if context is None:
context = self.get_template_context()

if error:
if isinstance(error, BaseException):
tb = self.get_truncated_error_traceback(error, truncate_to=self._execute_task)
@@ -1859,7 +1856,7 @@

self.end_date = timezone.utcnow()
self.set_duration()
Stats.incr(f'operator_failures_{self.task.task_type}')
Stats.incr(f'operator_failures_{self.operator}')
Stats.incr('ti_failures')
if not test_mode:
session.add(Log(State.FAILED, self))
@@ -1869,6 +1866,10 @@

self.clear_next_method_args()

# In extreme cases (e.g. a zombie for a dag with a parse error) we might _not_ have a Task.
if context is None and self.task:
context = self.get_template_context(session)

if context is not None:
context['exception'] = error

@@ -1886,7 +1887,8 @@

task: Optional[BaseOperator] = None
try:
task = self.task.unmap((context, session))
if self.task and context:
task = self.task.unmap((context, session))
except Exception:
self.log.error("Unable to unmap task to determine if we need to send an alert email")

@@ -1911,7 +1913,7 @@
except Exception:
self.log.exception('Failed to send email to: %s', task.email)

if callback:
if callback and context:
self._run_finished_callback(callback, context, callback_type)

if not test_mode:
@@ -1924,6 +1926,9 @@ def is_eligible_to_retry(self):
# If a task is cleared when running, it goes into RESTARTING state and is always
# eligible for retry
return True
if not self.task:
# Couldn't load the task, don't know number of retries, guess:
return self.try_number <= self.max_tries

return self.task.retries and self.try_number <= self.max_tries

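Taken together, the taskinstance.py changes mean handle_failure can be driven without a task object attached. A rough sketch of that situation, mirroring the test_handle_failure_no_task test added further down (the identifiers are illustrative, not taken from the source):

# Sketch: progress a zombie TI whose DAG no longer parses.
ti = session.query(TaskInstance).filter_by(
    dag_id="broken_dag", task_id="mytask", run_id="manual__2022-08-05"  # illustrative ids
).one()
ti.task = None  # no task object available because the DAG file failed to parse
ti.handle_failure("Detected as zombie", session=session)
# Failure stats now use the stored operator name (ti.operator) rather than
# self.task.task_type, and retry eligibility falls back to try_number <= max_tries,
# so the TI finally leaves the running/queued state instead of being re-detected forever.
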
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -621,7 +621,7 @@ def cleanup(self):
if not dag_ids:
return
# To isolate problems here with problems from elsewhere on the session object
self.session.flush()
self.session.rollback()

self.session.query(SerializedDagModel).filter(
SerializedDagModel.dag_id.in_(dag_ids)
41 changes: 38 additions & 3 deletions tests/dag_processing/test_processor.py
@@ -30,6 +30,7 @@
from airflow.dag_processing.manager import DagFileProcessorAgent
from airflow.dag_processing.processor import DagFileProcessor
from airflow.models import DagBag, DagModel, SlaMiss, TaskInstance, errors
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.utils import timezone
@@ -388,10 +389,44 @@ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message"
)
]
dag_file_processor.execute_callbacks(dagbag, requests)
dag_file_processor.execute_callbacks(dagbag, requests, session)
mock_ti_handle_failure.assert_called_once_with(
error="Message", test_mode=conf.getboolean('core', 'unit_test_mode'), session=session
)

@pytest.mark.parametrize(
["has_serialized_dag"],
[pytest.param(True, id="dag_in_db"), pytest.param(False, id="no_dag_found")],
)
@patch.object(TaskInstance, 'handle_failure')
def test_execute_on_failure_callbacks_without_dag(self, mock_ti_handle_failure, has_serialized_dag):
dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
with create_session() as session:
session.query(TaskInstance).delete()
dag = dagbag.get_dag('example_branch_operator')
dagrun = dag.create_dagrun(
state=State.RUNNING,
execution_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
session=session,
)
task = dag.get_task(task_id='run_this_first')
ti = TaskInstance(task, run_id=dagrun.run_id, state=State.QUEUED)
session.add(ti)

if has_serialized_dag:
assert SerializedDagModel.write_dag(dag, session=session) is True
session.flush()

requests = [
TaskCallbackRequest(
full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message"
)
]
dag_file_processor.execute_callbacks_without_dag(requests, session)
mock_ti_handle_failure.assert_called_once_with(
error="Message",
test_mode=conf.getboolean('core', 'unit_test_mode'),
session=session,
)

def test_failure_callbacks_should_not_drop_hostname(self):
33 changes: 32 additions & 1 deletion tests/models/test_taskinstance.py
@@ -2081,7 +2081,7 @@ def test_handle_failure_updates_queued_task_try_number(self, dag_maker):
ti = TI(task=task, run_id=dr.run_id)
ti.state = State.QUEUED
session.merge(ti)
session.commit()
session.flush()
assert ti.state == State.QUEUED
assert ti.try_number == 1
ti.handle_failure("test queued ti", test_mode=True)
@@ -2091,6 +2091,37 @@
# Check 'ti.try_number' is bumped to 2. This is try_number for next run
assert ti.try_number == 2

@patch.object(Stats, 'incr')
def test_handle_failure_no_task(self, Stats_incr, dag_maker):
"""
When a zombie is detected for a DAG with a parse error, we need to be able to run handle_failure
_without_ ti.task being set
"""
session = settings.Session()
with dag_maker():
task = EmptyOperator(task_id="mytask", retries=1)
dr = dag_maker.create_dagrun()
ti = TI(task=task, run_id=dr.run_id)
ti = session.merge(ti)
ti.task = None
ti.state = State.QUEUED
session.flush()

assert ti.task is None, "Check critical pre-condition"

assert ti.state == State.QUEUED
assert ti.try_number == 1

ti.handle_failure("test queued ti", test_mode=False)
assert ti.state == State.UP_FOR_RETRY
# Assert that 'ti._try_number' is bumped from 0 to 1. This is the last/current try
assert ti._try_number == 1
# Check 'ti.try_number' is bumped to 2. This is try_number for next run
assert ti.try_number == 2

Stats_incr.assert_any_call('ti_failures')
Stats_incr.assert_any_call('operator_failures_EmptyOperator')

def test_does_not_retry_on_airflow_fail_exception(self, dag_maker):
def fail():
raise AirflowFailException("hopeless")
