Skip to content

Commit

Permalink
Don't fail DagRun when leaf mapped_task is SKIPPED (apache#25995)
Browse files Browse the repository at this point in the history
This one was a fun one to track down, and was only a problem when the
scheduler "expanded" the first mapped task (making it SKIPPED).

- The scheduler looks at `add_one` and marks it as SKIPPED.
- `unfinished_tis` contains `add_one__1`, and `_are_premature_tis` changes
  the state of `add_one__1` (which just so happens to be a leaf task),
  and the check on line 584 essentially gets confused, as no one
  envisaged the states changing!

In a reverse of how this normally plays out, if the mini_scheduler in
the LocalTaskJob was disabled then this example DAG would deadlock every
time. (Since that mini scheduler only operates on a partial DAG it can't
ever change the whole DagRun state.)

Co-authored-by: Ephraim Anierobi <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
  • Loading branch information
3 people authored Sep 1, 2022
1 parent 1e19807 commit 5697e9f
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 58 deletions.
2 changes: 1 addition & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ def __init__(
f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'."
)

self.trigger_rule = TriggerRule(trigger_rule)
self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
self.depends_on_past: bool = depends_on_past
self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
self.wait_for_downstream: bool = wait_for_downstream
Expand Down
77 changes: 43 additions & 34 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,25 @@ def update_state(
# Callback to execute in case of Task Failures
callback: Optional[DagCallbackRequest] = None

class _UnfinishedStates(NamedTuple):
    """Snapshot of the task instances that are still unfinished.

    Wrapping the sequence lets the scheduling checks be re-evaluated
    after the snapshot is rebuilt with :meth:`recalculate`.
    """

    tis: Sequence[TI]

    @classmethod
    def calculate(cls, unfinished_tis: Sequence[TI]) -> "_UnfinishedStates":
        """Create a snapshot holding the given unfinished task instances."""
        return cls(tis=unfinished_tis)

    @property
    def should_schedule(self) -> bool:
        """Whether the held task instances pass every scheduling pre-check.

        True only when there is at least one task instance and none of
        them depends on past runs, has ``max_active_tis_per_dag`` set, or
        is in the DEFERRED state.
        """
        if not self.tis:
            return False
        if any(ti.task.depends_on_past for ti in self.tis):
            return False
        if any(ti.task.max_active_tis_per_dag is not None for ti in self.tis):
            return False
        return all(ti.state != TaskInstanceState.DEFERRED for ti in self.tis)

    def recalculate(self) -> "_UnfinishedStates":
        """Return a new snapshot keeping only TIs whose state is still unfinished."""
        still_unfinished = [ti for ti in self.tis if ti.state in State.unfinished]
        return self._replace(tis=still_unfinished)

start_dttm = timezone.utcnow()
self.last_scheduling_decision = start_dttm
with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"):
Expand All @@ -531,25 +550,23 @@ def update_state(
schedulable_tis = info.schedulable_tis
changed_tis = info.changed_tis
finished_tis = info.finished_tis
unfinished_tis = info.unfinished_tis

none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tis)
none_task_concurrency = all(t.task.max_active_tis_per_dag is None for t in unfinished_tis)
none_deferred = all(t.state != State.DEFERRED for t in unfinished_tis)
unfinished = _UnfinishedStates.calculate(info.unfinished_tis)

if unfinished_tis and none_depends_on_past and none_task_concurrency and none_deferred:
if unfinished.should_schedule:
are_runnable_tasks = schedulable_tis or changed_tis
# small speed up
are_runnable_tasks = (
schedulable_tis
or self._are_premature_tis(unfinished_tis, finished_tis, session)
or changed_tis
)
if not are_runnable_tasks:
are_runnable_tasks, changed_by_upstream = self._are_premature_tis(
unfinished.tis, finished_tis, session
)
if changed_by_upstream: # Something changed, we need to recalculate!
unfinished = unfinished.recalculate()

leaf_task_ids = {t.task_id for t in dag.leaves}
leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids if ti.state != TaskInstanceState.REMOVED]

# if all roots finished and at least one failed, the run failed
if not unfinished_tis and any(leaf_ti.state in State.failed_states for leaf_ti in leaf_tis):
if not unfinished.tis and any(leaf_ti.state in State.failed_states for leaf_ti in leaf_tis):
self.log.error('Marking run %s failed', self)
self.set_state(DagRunState.FAILED)
if execute_callbacks:
Expand All @@ -564,7 +581,7 @@ def update_state(
)

# if all leaves succeeded and no unfinished tasks, the run succeeded
elif not unfinished_tis and all(leaf_ti.state in State.success_states for leaf_ti in leaf_tis):
elif not unfinished.tis and all(leaf_ti.state in State.success_states for leaf_ti in leaf_tis):
self.log.info('Marking run %s successful', self)
self.set_state(DagRunState.SUCCESS)
if execute_callbacks:
Expand All @@ -579,13 +596,7 @@ def update_state(
)

# if *all tasks* are deadlocked, the run failed
elif (
unfinished_tis
and none_depends_on_past
and none_task_concurrency
and none_deferred
and not are_runnable_tasks
):
elif unfinished.should_schedule and not are_runnable_tasks:
self.log.error('Deadlock; marking run %s failed', self)
self.set_state(DagRunState.FAILED)
if execute_callbacks:
Expand Down Expand Up @@ -744,24 +755,22 @@ def _get_ready_tis(

def _are_premature_tis(
self,
unfinished_tis: List[TI],
unfinished_tis: Sequence[TI],
finished_tis: List[TI],
session: Session,
) -> bool:
) -> Tuple[bool, bool]:
dep_context = DepContext(
flag_upstream_failed=True,
ignore_in_retry_period=True,
ignore_in_reschedule_period=True,
finished_tis=finished_tis,
)
# there might be runnable tasks that are up for retry and for some reason(retry delay, etc) are
# not ready yet so we set the flags to count them in
for ut in unfinished_tis:
if ut.are_dependencies_met(
dep_context=DepContext(
flag_upstream_failed=True,
ignore_in_retry_period=True,
ignore_in_reschedule_period=True,
finished_tis=finished_tis,
),
session=session,
):
return True
return False
return (
any(ut.are_dependencies_met(dep_context=dep_context, session=session) for ut in unfinished_tis),
dep_context.have_changed_ti_states,
)

def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: List[TI]) -> None:
"""
Expand Down
7 changes: 6 additions & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,13 +947,17 @@ def key(self) -> TaskInstanceKey:
return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)

@provide_session
def set_state(self, state: Optional[str], session: Session = NEW_SESSION) -> None:
def set_state(self, state: Optional[str], session: Session = NEW_SESSION) -> bool:
"""
Set TaskInstance state.
:param state: State to set for the TI
:param session: SQLAlchemy ORM Session
:return: Was the state changed
"""
if self.state == state:
return False

current_time = timezone.utcnow()
self.log.debug("Setting task state for %s to %s", self, state)
self.state = state
Expand All @@ -962,6 +966,7 @@ def set_state(self, state: Optional[str], session: Session = NEW_SESSION) -> Non
self.end_date = self.end_date or current_time
self.duration = (self.end_date - self.start_date).total_seconds()
session.merge(self)
return True

@property
def is_premature(self) -> bool:
Expand Down
3 changes: 3 additions & 0 deletions airflow/ti_deps/dep_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ class DepContext:
ignore_unmapped_tasks: bool = False
finished_tis: Optional[List["TaskInstance"]] = None

have_changed_ti_states: bool = False
"""Have any of the TIs state's been changed as a result of evaluating dependencies"""

def ensure_finished_tis(self, dag_run: "DagRun", session: Session) -> List["TaskInstance"]:
"""
This method makes sure finished_tis is populated if it's currently None.
Expand Down
35 changes: 18 additions & 17 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _get_dep_statuses(self, ti, session, dep_context: DepContext):
upstream_failed=upstream_failed,
done=done,
flag_upstream_failed=dep_context.flag_upstream_failed,
dep_context=dep_context,
session=session,
)

Expand Down Expand Up @@ -115,13 +116,14 @@ def _count_upstreams(ti: "TaskInstance", *, session: "Session"):
@provide_session
def _evaluate_trigger_rule(
self,
ti,
ti: "TaskInstance",
successes,
skipped,
failed,
upstream_failed,
done,
flag_upstream_failed,
dep_context: DepContext,
session: "Session" = NEW_SESSION,
):
"""
Expand All @@ -142,7 +144,7 @@ def _evaluate_trigger_rule(
"""
task = ti.task
upstream = self._count_upstreams(ti, session=session)
trigger_rule: TR = task.trigger_rule
trigger_rule = task.trigger_rule
upstream_done = done >= upstream
upstream_tasks_state = {
"total": upstream,
Expand All @@ -152,43 +154,42 @@ def _evaluate_trigger_rule(
"upstream_failed": upstream_failed,
"done": done,
}
# TODO(aoen): Ideally each individual trigger rules would be its own class, but
# this isn't very feasible at the moment since the database queries need to be
# bundled together for efficiency.
# handling instant state assignment based on trigger rules
changed: bool = False
if flag_upstream_failed:
if trigger_rule == TR.ALL_SUCCESS:
if upstream_failed or failed:
ti.set_state(State.UPSTREAM_FAILED, session)
changed = ti.set_state(State.UPSTREAM_FAILED, session)
elif skipped:
ti.set_state(State.SKIPPED, session)
changed = ti.set_state(State.SKIPPED, session)
elif trigger_rule == TR.ALL_FAILED:
if successes or skipped:
ti.set_state(State.SKIPPED, session)
changed = ti.set_state(State.SKIPPED, session)
elif trigger_rule == TR.ONE_SUCCESS:
if upstream_done and done == skipped:
# if upstream is done and all are skipped mark as skipped
ti.set_state(State.SKIPPED, session)
changed = ti.set_state(State.SKIPPED, session)
elif upstream_done and successes <= 0:
# if upstream is done and there are no successes mark as upstream failed
ti.set_state(State.UPSTREAM_FAILED, session)
changed = ti.set_state(State.UPSTREAM_FAILED, session)
elif trigger_rule == TR.ONE_FAILED:
if upstream_done and not (failed or upstream_failed):
ti.set_state(State.SKIPPED, session)
changed = ti.set_state(State.SKIPPED, session)
elif trigger_rule == TR.NONE_FAILED:
if upstream_failed or failed:
ti.set_state(State.UPSTREAM_FAILED, session)
changed = ti.set_state(State.UPSTREAM_FAILED, session)
elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
if upstream_failed or failed:
ti.set_state(State.UPSTREAM_FAILED, session)
changed = ti.set_state(State.UPSTREAM_FAILED, session)
elif skipped == upstream:
ti.set_state(State.SKIPPED, session)
changed = ti.set_state(State.SKIPPED, session)
elif trigger_rule == TR.NONE_SKIPPED:
if skipped:
ti.set_state(State.SKIPPED, session)
changed = ti.set_state(State.SKIPPED, session)
elif trigger_rule == TR.ALL_SKIPPED:
if successes or failed:
ti.set_state(State.SKIPPED, session)
changed = ti.set_state(State.SKIPPED, session)
if changed:
dep_context.have_changed_ti_states = True

if trigger_rule == TR.ONE_SUCCESS:
if successes <= 0:
Expand Down
40 changes: 35 additions & 5 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,19 +289,20 @@ def test_dagrun_deadlock(self, session):
execution_date=now,
data_interval=dag.timetable.infer_manual_data_interval(run_after=now),
start_date=now,
session=session,
)

ti_op1 = dr.get_task_instance(task_id=op1.task_id)
ti_op1: TI = dr.get_task_instance(task_id=op1.task_id, session=session)
ti_op2: TI = dr.get_task_instance(task_id=op2.task_id, session=session)
ti_op1.set_state(state=TaskInstanceState.SUCCESS, session=session)
ti_op2 = dr.get_task_instance(task_id=op2.task_id)
ti_op2.set_state(state=None, session=session)

dr.update_state()
dr.update_state(session=session)
assert dr.state == DagRunState.RUNNING

ti_op2.set_state(state=None, session=session)
op2.trigger_rule = 'invalid'
dr.update_state()
op2.trigger_rule = 'invalid' # type: ignore
dr.update_state(session=session)
assert dr.state == DagRunState.FAILED

def test_dagrun_no_deadlock_with_shutdown(self, session):
Expand Down Expand Up @@ -1868,3 +1869,32 @@ def task_1(args_0):
("task_4", 1): None,
("task_4", 2): None,
}


def test_mapped_skip_upstream_not_deadlock(dag_maker):
    """A run whose mapped tasks expand over empty lists must end up SUCCESS.

    Builds ``say_hi >> add_one >> add_one__1`` where both ``add_one`` tasks
    are expanded over an empty list, marks ``say_hi`` as succeeded, and
    checks that updating the DagRun state yields SUCCESS with the leaf
    mapped task SKIPPED.
    """
    with dag_maker() as dag:

        @dag.task
        def add_one(x: int):
            return x + 1

        @dag.task
        def say_hi():
            print("Hi")

        # Creation order matters: the second expansion is the one the
        # assertion below looks up as ``add_one__1``.
        added_values = add_one.expand(x=[])
        added_more_values = add_one.expand(x=[])
        say_hi() >> added_values
        added_values >> added_more_values

    dag_run = dag_maker.create_dagrun()
    session = dag_maker.session

    ti_by_task_id = {ti.task_id: ti for ti in dag_run.task_instances}

    # Pretend the only non-mapped upstream task has already succeeded.
    ti_by_task_id['say_hi'].state = TaskInstanceState.SUCCESS
    session.flush()

    dag_run.update_state(session=session)
    assert dag_run.state == DagRunState.SUCCESS
    assert ti_by_task_id['add_one__1'].state == TaskInstanceState.SKIPPED
2 changes: 2 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from airflow.sensors.python import PythonSensor
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.ti_deps.dependencies_states import RUNNABLE_STATES
from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
Expand Down Expand Up @@ -1152,6 +1153,7 @@ def test_check_task_dependencies(
failed=failed,
upstream_failed=upstream_failed,
done=done,
dep_context=DepContext(),
flag_upstream_failed=flag_upstream_failed,
)
completed = all(dep.passed for dep in dep_results)
Expand Down
Loading

0 comments on commit 5697e9f

Please sign in to comment.