Skip to content

Commit

Permalink
[AIRFLOW-719] Prevent DAGs from ending prematurely
Browse files Browse the repository at this point in the history
DAGs using the ALL_SUCCESS and ONE_SUCCESS trigger
rules were ending
prematurely when upstream tasks were skipped.
With this change, the
ALL_SUCCESS and ONE_SUCCESS trigger rules
count both SUCCESS and
SKIPPED upstream tasks as satisfying the rule.

Closes apache#2125 from dhuang/AIRFLOW-719
  • Loading branch information
dhuang authored and bolkedebruin committed Mar 4, 2017
1 parent 7764c75 commit 1fdcf24
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 12 deletions.
6 changes: 3 additions & 3 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _evaluate_trigger_rule(
if tr == TR.ALL_SUCCESS:
if upstream_failed or failed:
ti.set_state(State.UPSTREAM_FAILED, session)
elif skipped:
elif skipped == upstream:
ti.set_state(State.SKIPPED, session)
elif tr == TR.ALL_FAILED:
if successes or skipped:
Expand All @@ -148,7 +148,7 @@ def _evaluate_trigger_rule(
ti.set_state(State.SKIPPED, session)

if tr == TR.ONE_SUCCESS:
if successes <= 0:
if successes <= 0 and skipped <= 0:
yield self._failing_status(
reason="Task's trigger rule '{0}' requires one upstream "
"task success, but none were found. "
Expand All @@ -162,7 +162,7 @@ def _evaluate_trigger_rule(
"upstream_tasks_state={1}, upstream_task_ids={2}"
.format(tr, upstream_tasks_state, task.upstream_task_ids))
elif tr == TR.ALL_SUCCESS:
num_failures = upstream - successes
num_failures = upstream - (successes + skipped)
if num_failures > 0:
yield self._failing_status(
reason="Task's trigger rule '{0}' requires all upstream "
Expand Down
38 changes: 38 additions & 0 deletions tests/dags/test_dagrun_short_circuit_false.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime

from airflow.models import DAG
from airflow.operators.python_operator import ShortCircuitOperator
from airflow.operators.dummy_operator import DummyOperator


# Test DAG: a ShortCircuitOperator whose callable returns False, followed by
# a chain of two dummy tasks that are expected to end up skipped.
dag = DAG(
    start_date=datetime(2017, 1, 1),
    dag_id='test_dagrun_short_circuit_false',
)

# Head of the chain: the short-circuit condition always evaluates to False.
dag_task1 = ShortCircuitOperator(
    dag=dag,
    task_id='test_short_circuit_false',
    python_callable=lambda: False,
)

# Two downstream tasks, chained one after the other.
dag_task2 = DummyOperator(dag=dag, task_id='test_state_skipped1')
dag_task3 = DummyOperator(dag=dag, task_id='test_state_skipped2')

# test_short_circuit_false -> test_state_skipped1 -> test_state_skipped2
dag_task1.set_downstream(dag_task2)
dag_task2.set_downstream(dag_task3)
79 changes: 70 additions & 9 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from airflow.utils.state import State
from mock import patch
from nose_parameterized import parameterized
from tests.core import TEST_DAG_FOLDER

DEFAULT_DATE = datetime.datetime(2016, 1, 1)
TEST_DAGS_FOLDER = os.path.join(
Expand Down Expand Up @@ -117,13 +118,71 @@ def test_dag_as_context_manager(self):
self.assertEqual(dag.dag_id, 'creating_dag_in_cm')
self.assertEqual(dag.tasks[0].task_id, 'op6')


class DagRunTest(unittest.TestCase):
    """Tests for ``models.DagRun``: run-id generation and state updates."""

    def setUp(self):
        # Build the DagBag from the test DAG folder so the DAGs referenced
        # by id in the tests below can be looked up.
        self.dagbag = models.DagBag(dag_folder=TEST_DAG_FOLDER)

    def create_dag_run(self, dag_id, state=State.RUNNING, task_states=None):
        """
        Create a DAG run for ``dag_id`` and optionally preset task states.

        :param dag_id: id of a DAG available in ``self.dagbag``
        :param state: state the DAG run is created with
        :param task_states: optional mapping of task_id to the state that
            the matching task instance of the new run is set to
        :return: the DAG run returned by ``dag.create_dagrun``
        """
        now = datetime.datetime.now()
        dag = self.dagbag.get_dag(dag_id)
        dag_run = dag.create_dagrun(
            run_id='manual__' + now.isoformat(),
            execution_date=now,
            start_date=now,
            # Honour the caller-supplied state instead of a hard-coded value.
            state=state,
            external_trigger=False,
        )

        if task_states is not None:
            session = settings.Session()
            try:
                # Use a distinct loop name so the ``state`` parameter is
                # not shadowed by the per-task state.
                for task_id, task_state in task_states.items():
                    ti = dag_run.get_task_instance(task_id)
                    ti.set_state(task_state, session)
            finally:
                # Always release the session, even if set_state raises.
                session.close()

        return dag_run

    def test_id_for_date(self):
        run_id = models.DagRun.id_for_date(
            datetime.datetime(2015, 1, 2, 3, 4, 5, 6, None))
        self.assertEqual(
            'scheduled__2015-01-02T03:04:05', run_id,
            'Generated run_id did not match expectations: {0}'.format(run_id))

    def test_dagrun_running_when_upstream_skipped(self):
        """
        Tests that a DAG run is not failed when an upstream task is skipped
        """
        initial_task_states = {
            'test_short_circuit_false': State.SUCCESS,
            'test_state_skipped1': State.SKIPPED,
            'test_state_skipped2': State.NONE,
        }
        # dags/test_dagrun_short_circuit_false.py
        dag_run = self.create_dag_run('test_dagrun_short_circuit_false',
                                      state=State.RUNNING,
                                      task_states=initial_task_states)
        updated_dag_state = dag_run.update_state()
        self.assertEqual(State.RUNNING, updated_dag_state)

    def test_dagrun_success_when_all_skipped(self):
        """
        Tests that a DAG run succeeds when all tasks are skipped
        """
        initial_task_states = {
            'test_short_circuit_false': State.SUCCESS,
            'test_state_skipped1': State.SKIPPED,
            'test_state_skipped2': State.SKIPPED,
        }
        # dags/test_dagrun_short_circuit_false.py
        dag_run = self.create_dag_run('test_dagrun_short_circuit_false',
                                      state=State.RUNNING,
                                      task_states=initial_task_states)
        updated_dag_state = dag_run.update_state()
        self.assertEqual(State.SUCCESS, updated_dag_state)


class DagBagTest(unittest.TestCase):

Expand Down Expand Up @@ -488,7 +547,7 @@ def test_next_retry_datetime(self):
self.assertEqual(dt, ti.end_date+max_delay)

def test_depends_on_past(self):
dagbag = models.DagBag()
dagbag = models.DagBag(dag_folder=TEST_DAG_FOLDER)
dag = dagbag.get_dag('test_depends_on_past')
dag.clear()
task = dag.tasks[0]
Expand Down Expand Up @@ -517,17 +576,19 @@ def test_depends_on_past(self):
#
# Tests for all_success
#
['all_success', 5, 0, 0, 0, 0, True, None, True],
['all_success', 2, 0, 0, 0, 0, True, None, False],
['all_success', 2, 0, 1, 0, 0, True, ST.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 0, True, ST.SKIPPED, False],
['all_success', 5, 0, 0, 0, 5, True, None, True],
['all_success', 2, 0, 0, 0, 2, True, None, False],
['all_success', 2, 0, 1, 0, 3, True, ST.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 3, True, None, False],
['all_success', 0, 5, 0, 0, 5, True, ST.SKIPPED, True],
#
# Tests for one_success
#
['one_success', 5, 0, 0, 0, 5, True, None, True],
['one_success', 2, 0, 0, 0, 2, True, None, True],
['one_success', 2, 0, 1, 0, 3, True, None, True],
['one_success', 2, 1, 0, 0, 3, True, None, True],
['one_success', 0, 2, 0, 0, 2, True, None, True],
#
# Tests for all_failed
#
Expand All @@ -539,9 +600,9 @@ def test_depends_on_past(self):
#
# Tests for one_failed
#
['one_failed', 5, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 1, 0, 0, True, None, True],
['one_failed', 5, 0, 0, 0, 5, True, ST.SKIPPED, False],
['one_failed', 2, 0, 0, 0, 2, True, None, False],
['one_failed', 2, 0, 1, 0, 2, True, None, True],
['one_failed', 2, 1, 0, 0, 3, True, None, False],
['one_failed', 2, 3, 0, 0, 5, True, ST.SKIPPED, False],
#
Expand Down

0 comments on commit 1fdcf24

Please sign in to comment.