Allow Operators to specify SKIPPED status internally
* Added ability to skip DAG elements based on raised Exception

* Added nose-parameterized to test dependencies

* Fix for broken mysql test - provided by jlowin
withnale authored and jlowin committed Apr 6, 2016
1 parent 8afee07 commit 81ff5cc
Showing 8 changed files with 258 additions and 47 deletions.
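
In practice this change means a task can end in the new SKIPPED state simply by raising AirflowSkipException from its execute path. A minimal sketch of the pattern, assuming a PythonOperator and the import style used by the example DAG below (the dag id, callable, and skip condition are illustrative, not part of this commit):

from datetime import datetime

from airflow.exceptions import AirflowSkipException
from airflow.models import DAG
from airflow.operators import PythonOperator

dag = DAG(dag_id='skip_sketch', start_date=datetime(2016, 4, 1))


def maybe_skip(**context):
    # Hypothetical guard: on weekends there is nothing to do, so mark the
    # task SKIPPED (and let downstream trigger rules decide what runs next)
    # instead of failing it.
    if context['execution_date'].weekday() >= 5:
        raise AirflowSkipException('nothing to do on weekends')
    return 'did work'


maybe_skip_task = PythonOperator(
    task_id='maybe_skip',
    python_callable=maybe_skip,
    provide_context=True,
    dag=dag)
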
63 changes: 63 additions & 0 deletions airflow/example_dags/example_skip_dag.py
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from airflow.operators import DummyOperator
from airflow.models import DAG
from datetime import datetime, timedelta
from airflow.exceptions import AirflowSkipException

seven_days_ago = datetime.combine(datetime.today() - timedelta(1),
                                  datetime.min.time())
args = {
    'owner': 'airflow',
    'start_date': seven_days_ago,
}


# Create some placeholder operators
class DummySkipOperator(DummyOperator):
    ui_color = '#e8b7e4'

    def execute(self, context):
        raise AirflowSkipException


dag = DAG(dag_id='example_skip_dag', default_args=args)


def create_test_pipeline(suffix, trigger_rule, dag):

    skip_operator = DummySkipOperator(task_id='skip_operator_{}'.format(suffix), dag=dag)

    always_true = DummyOperator(task_id='always_true_{}'.format(suffix), dag=dag)

    join = DummyOperator(task_id=trigger_rule, dag=dag, trigger_rule=trigger_rule)

    join.set_upstream(skip_operator)
    join.set_upstream(always_true)

    final = DummyOperator(task_id='final_{}'.format(suffix), dag=dag)
    final.set_upstream(join)


create_test_pipeline('1', 'all_success', dag)
create_test_pipeline('2', 'one_success', dag)
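
Given the trigger-rule handling added to models.py below, the two pipelines built here are expected to settle into different end states; a rough sketch of the intent (illustrative, not the output of an actual run):

# Pipeline 1 (join uses trigger_rule='all_success'):
#   skip_operator_1 -> SKIPPED  (execute raises AirflowSkipException)
#   always_true_1   -> SUCCESS
#   all_success     -> SKIPPED  (a skipped upstream blocks ALL_SUCCESS)
#   final_1         -> SKIPPED  (its only upstream was skipped)
#
# Pipeline 2 (join uses trigger_rule='one_success'):
#   skip_operator_2 -> SKIPPED
#   always_true_2   -> SUCCESS
#   one_success     -> SUCCESS  (one successful upstream is enough)
#   final_2         -> SUCCESS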


16 changes: 14 additions & 2 deletions airflow/exceptions.py
@@ -1,10 +1,22 @@

#
# Any AirflowException raised is expected to cause the TaskInstance to be marked in an ERROR state
#
class AirflowException(Exception):
    pass


class AirflowSensorTimeout(Exception):
class AirflowSensorTimeout(AirflowException):
    pass


class AirflowTaskTimeout(Exception):
class AirflowTaskTimeout(AirflowException):
    pass


#
# Any AirflowSkipException raised is expected to cause the TaskInstance to be marked in a SKIPPED state
#
class AirflowSkipException(AirflowException):
    pass
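
Since AirflowSkipException subclasses AirflowException, any handler that already catches AirflowException will also catch a skip; code that wants to treat skips differently therefore has to catch the subclass first, which is what the run() change in models.py below does. A small illustrative sketch of that ordering (not Airflow code):

from airflow.exceptions import AirflowException, AirflowSkipException


def run_task(execute):
    try:
        execute()
    except AirflowSkipException:
        # Must be handled before the broader AirflowException clause,
        # otherwise a skip would be recorded as a failure.
        return 'SKIPPED'
    except AirflowException:
        return 'FAILED'
    return 'SUCCESS'


def skipping_task():
    raise AirflowSkipException('nothing to do')


assert run_task(lambda: None) == 'SUCCESS'
assert run_task(skipping_task) == 'SKIPPED'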

105 changes: 69 additions & 36 deletions airflow/models.py
@@ -54,7 +54,7 @@
from airflow import settings, utils
from airflow.executors import DEFAULT_EXECUTOR, LocalExecutor
from airflow import configuration
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.utils.dates import cron_presets, date_range as utils_date_range
from airflow.utils.db import provide_session
from airflow.utils.decorators import apply_defaults
@@ -856,6 +856,62 @@ def are_dependents_done(self, session=None):
        count = ti[0][0]
        return count == len(task.downstream_task_ids)

    @provide_session
    def evaluate_trigger_rule(self, successes, skipped, failed,
                              upstream_failed, done,
                              flag_upstream_failed, session=None):
        """
        Returns a boolean on whether the current task can be scheduled
        for execution based on its trigger_rule.

        :param flag_upstream_failed: This is a hack to generate
            the upstream_failed state creation while checking to see
            whether the task instance is runnable. It was the shortest
            path to add the feature
        :type flag_upstream_failed: boolean
        :param successes: Number of successful upstream tasks
        :type successes: int
        :param skipped: Number of skipped upstream tasks
        :type skipped: int
        :param failed: Number of failed upstream tasks
        :type failed: int
        :param upstream_failed: Number of upstream_failed upstream tasks
        :type upstream_failed: int
        :param done: Number of completed upstream tasks
        :type done: int
        """
        TR = TriggerRule

        task = self.task
        upstream = len(task.upstream_task_ids)
        tr = task.trigger_rule
        upstream_done = done >= upstream

        # handling instant state assignment based on trigger rules
        if flag_upstream_failed:
            if tr == TR.ALL_SUCCESS:
                if upstream_failed or failed:
                    self.set_state(State.UPSTREAM_FAILED, session)
                elif skipped:
                    self.set_state(State.SKIPPED, session)
            elif tr == TR.ALL_FAILED:
                if successes or skipped:
                    self.set_state(State.SKIPPED, session)
            elif tr == TR.ONE_SUCCESS:
                if upstream_done and not successes:
                    self.set_state(State.SKIPPED, session)
            elif tr == TR.ONE_FAILED:
                if upstream_done and not (failed or upstream_failed):
                    self.set_state(State.SKIPPED, session)

        return (
            (tr == TR.ONE_SUCCESS and successes > 0) or
            (tr == TR.ONE_FAILED and (failed or upstream_failed)) or
            (tr == TR.ALL_SUCCESS and successes >= upstream) or
            (tr == TR.ALL_FAILED and failed + upstream_failed >= upstream) or
            (tr == TR.ALL_DONE and upstream_done)
        )

    @provide_session
    def are_dependencies_met(
            self,
@@ -935,41 +991,16 @@ def are_dependencies_met
State.UPSTREAM_FAILED, State.SKIPPED]),
)
)
        successes, skipped, failed, upstream_failed, done = qry.first()
        upstream = len(task.upstream_task_ids)
        tr = task.trigger_rule
        upstream_done = done >= upstream

        # handling instant state assignment based on trigger rules
        if flag_upstream_failed:
            if tr == TR.ALL_SUCCESS:
                if upstream_failed or failed:
                    self.set_state(State.UPSTREAM_FAILED, session)
                elif skipped:
                    self.set_state(State.SKIPPED, session)
            elif tr == TR.ALL_FAILED:
                if successes or skipped:
                    self.set_state(State.SKIPPED, session)
            elif tr == TR.ONE_SUCCESS:
                if upstream_done and not successes:
                    self.set_state(State.SKIPPED, session)
            elif tr == TR.ONE_FAILED:
                if upstream_done and not(failed or upstream_failed):
                    self.set_state(State.SKIPPED, session)

        if (
            (tr == TR.ONE_SUCCESS and successes) or
            (tr == TR.ONE_FAILED and (failed or upstream_failed)) or
            (tr == TR.ALL_SUCCESS and successes >= upstream) or
            (tr == TR.ALL_FAILED and failed + upstream_failed >= upstream) or
            (tr == TR.ALL_DONE and upstream_done)
        ):
            return True

        successes, skipped, failed, upstream_failed, done = qry.first()
        satisfied = self.evaluate_trigger_rule(
            session=session, successes=successes, skipped=skipped,
            failed=failed, upstream_failed=upstream_failed, done=done,
            flag_upstream_failed=flag_upstream_failed)
        session.commit()
        if verbose:
            logging.warning("Trigger rule `{}` not satisfied".format(tr))
        return False
        if verbose and not satisfied:
            logging.warning("Trigger rule `{}` not satisfied".format(task.trigger_rule))
        return satisfied

    def __repr__(self):
        return (
@@ -1141,16 +1172,18 @@ def signal_handler(signum, frame):
            self.xcom_push(key=XCOM_RETURN_KEY, value=result)

            task_copy.post_execute(context=context)
            self.state = State.SUCCESS
        except AirflowSkipException:
            self.state = State.SKIPPED
        except (Exception, KeyboardInterrupt) as e:
            self.handle_failure(e, test_mode, context)
            raise

        # Recording SUCCESS
        self.end_date = datetime.now()
        self.set_duration()
        self.state = State.SUCCESS
        if not test_mode:
            session.add(Log(State.SUCCESS, self))
            session.add(Log(self.state, self))
            session.merge(self)
        session.commit()

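To make the refactored rule evaluation concrete, here is a standalone sketch that mirrors the boolean returned by evaluate_trigger_rule (an illustration, not the Airflow API; trigger-rule names are written as plain strings) and evaluates it for the joins in example_skip_dag.py, each of which has one skipped and one successful upstream:

def rule_satisfied(tr, successes, skipped, failed, upstream_failed, done, upstream):
    # Mirrors the return expression of TaskInstance.evaluate_trigger_rule above.
    upstream_done = done >= upstream
    return (
        (tr == 'one_success' and successes > 0) or
        (tr == 'one_failed' and (failed or upstream_failed)) or
        (tr == 'all_success' and successes >= upstream) or
        (tr == 'all_failed' and failed + upstream_failed >= upstream) or
        (tr == 'all_done' and upstream_done)
    )


# One skipped and one successful upstream, as in example_skip_dag.py:
counts = dict(successes=1, skipped=1, failed=0, upstream_failed=0, done=2, upstream=2)

print(rule_satisfied('all_success', **counts))  # False -> task gets flagged SKIPPED
print(rule_satisfied('one_success', **counts))  # True  -> task runs
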
11 changes: 9 additions & 2 deletions airflow/operators/sensors.py
@@ -8,7 +8,7 @@
from time import sleep

from airflow import hooks, settings
from airflow.exceptions import AirflowException, AirflowSensorTimeout
from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException
from airflow.models import BaseOperator, TaskInstance, Connection as DB
from airflow.hooks import BaseHook
from airflow.utils.state import State
@@ -22,6 +22,8 @@ class BaseSensorOperator(BaseOperator):
    Sensor operators keep executing at a time interval and succeed when
    a criterion is met and fail if and when they time out.

    :param soft_fail: Set to true to mark the task as SKIPPED on failure
    :type soft_fail: bool
    :param poke_interval: Time in seconds that the job should wait in
        between each try
    :type poke_interval: int
@@ -35,9 +37,11 @@ def __init__(
            self,
            poke_interval=60,
            timeout=60*60*24*7,
            soft_fail=False,
            *args, **kwargs):
        super(BaseSensorOperator, self).__init__(*args, **kwargs)
        self.poke_interval = poke_interval
        self.soft_fail = soft_fail
        self.timeout = timeout

    def poke(self, context):
@@ -52,7 +56,10 @@ def execute(self, context):
        while not self.poke(context):
            sleep(self.poke_interval)
            if (datetime.now() - started_at).seconds > self.timeout:
                raise AirflowSensorTimeout('Snap. Time is OUT.')
                if self.soft_fail:
                    raise AirflowSkipException('Snap. Time is OUT.')
                else:
                    raise AirflowSensorTimeout('Snap. Time is OUT.')
        logging.info("Success criteria met. Exiting.")


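With soft_fail=True, a sensor that times out now raises AirflowSkipException instead of AirflowSensorTimeout, so the task ends up SKIPPED rather than failed. A minimal sketch, assuming BaseSensorOperator can be imported from airflow.operators.sensors as shown in the file above (the sensor class, dag id, and timings are made up for illustration):

from datetime import datetime

from airflow.models import DAG
from airflow.operators.sensors import BaseSensorOperator


class NeverReadySensor(BaseSensorOperator):
    # Illustrative sensor whose criterion is never met, so it always times out.
    def poke(self, context):
        return False


dag = DAG(dag_id='soft_fail_sketch', start_date=datetime(2016, 4, 1))

wait_then_skip = NeverReadySensor(
    task_id='wait_then_skip',
    poke_interval=10,
    timeout=60,       # give up after a minute...
    soft_fail=True,   # ...and mark the task SKIPPED instead of failing it
    dag=dag)
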
3 changes: 2 additions & 1 deletion scripts/ci/requirements.txt
@@ -22,6 +22,7 @@ ipython
jinja2
markdown
nose
nose-parameterized
nose-exclude
pandas
pygments
@@ -55,4 +56,4 @@ flask-bcrypt
mock
hdfs
thrift_sasl
impyla
impyla
2 changes: 1 addition & 1 deletion setup.py
@@ -95,7 +95,7 @@ def run(self):
qds = ['qds-sdk>=1.9.0']

all_dbs = postgres + mysql + hive + mssql + hdfs + vertica
devel = ['lxml>=3.3.4', 'nose', 'mock']
devel = ['lxml>=3.3.4', 'nose', 'nose-parameterized', 'mock']
devel_minreq = devel + mysql + doc + password + s3
devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos
devel_all = devel + all_dbs + doc + samba + s3 + slack + crypto + oracle + docker
5 changes: 3 additions & 2 deletions tests/jobs.py
@@ -342,8 +342,6 @@ def test_scheduler_start_date(self):
Test that the scheduler respects start_dates, even when DAGS have run
"""

session = settings.Session()

dag_id = 'test_start_date_scheduling'
dag = self.dagbag.get_dag(dag_id)
dag.clear()
@@ -353,6 +351,7 @@
scheduler.run()

# zero tasks ran
session = settings.Session()
self.assertEqual(
len(session.query(TI).filter(TI.dag_id == dag_id).all()), 0)

@@ -367,12 +366,14 @@
backfill.run()

# one task ran
session = settings.Session()
self.assertEqual(
len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1)

scheduler = SchedulerJob(dag_id, num_runs=2)
scheduler.run()

# still one task
session = settings.Session()
self.assertEqual(
len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1)
