Skip to content

Commit

Permalink
Merge pull request apache#1226 from jlowin/subdag_pool
Browse files Browse the repository at this point in the history
Validate subdag pools and add subdag unit tests
  • Loading branch information
bolkedebruin committed Apr 1, 2016
2 parents 3813d51 + 43769bc commit 78f5640
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 1 deletion.
45 changes: 44 additions & 1 deletion airflow/operators/subdag_operator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.models import BaseOperator, Pool
from airflow.utils.decorators import apply_defaults
from airflow.utils.db import provide_session
from airflow.executors import DEFAULT_EXECUTOR


Expand All @@ -10,6 +25,7 @@ class SubDagOperator(BaseOperator):
ui_color = '#555'
ui_fgcolor = '#fff'

@provide_session
@apply_defaults
def __init__(
self,
Expand All @@ -28,13 +44,40 @@ def __init__(
if 'dag' not in kwargs:
raise AirflowException("Please pass in the `dag` param")
dag = kwargs['dag']
session = kwargs.pop('session')
super(SubDagOperator, self).__init__(*args, **kwargs)

# validate subdag name
if dag.dag_id + '.' + kwargs['task_id'] != subdag.dag_id:
raise AirflowException(
"The subdag's dag_id should have the form "
"'{{parent_dag_id}}.{{this_task_id}}'. Expected "
"'{d}.{t}'; received '{rcvd}'.".format(
d=dag.dag_id, t=kwargs['task_id'], rcvd=subdag.dag_id))

# validate that subdag operator and subdag tasks don't have a
# pool conflict
if self.pool:
pool = (
session
.query(Pool)
.filter(Pool.slots == 1)
.filter(Pool.pool == self.pool)
.first()
)
conflicts = [t for t in subdag.tasks if t.pool == self.pool]
if pool and any(t.pool == self.pool for t in subdag.tasks):
raise AirflowException(
'SubDagOperator {sd} and subdag task{plural} {t} both use '
'pool {p}, but the pool only has 1 slot. The subdag tasks'
'will never run.'.format(
sd=self.task_id,
plural=len(conflicts) > 1,
t=', '.join(t.task_id for t in conflicts),
p=self.pool
)
)

self.subdag = subdag
self.executor = executor

Expand Down
1 change: 1 addition & 0 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def initdb():
"GROUP BY state"),
)
session.add(chart)
session.commit()


def upgradedb():
Expand Down
1 change: 1 addition & 0 deletions tests/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .docker_operator import *
from .subdag_operator import *
84 changes: 84 additions & 0 deletions tests/operators/subdag_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from datetime import datetime

import airflow
from airflow import DAG
from airflow.operators import DummyOperator
from airflow.operators.subdag_operator import SubDagOperator
from airflow.exceptions import AirflowException

default_args = dict(
owner='airflow',
start_date=datetime(2016, 1, 1),
)

class SubDagOperatorTests(unittest.TestCase):

def test_subdag_name(self):
"""
Subdag names must be {parent_dag}.{subdag task}
"""
dag = DAG('parent', default_args=default_args)
subdag_good = DAG('parent.test', default_args=default_args)
subdag_bad1 = DAG('parent.bad', default_args=default_args)
subdag_bad2 = DAG('bad.test', default_args=default_args)
subdag_bad3 = DAG('bad.bad', default_args=default_args)

SubDagOperator(task_id='test', dag=dag, subdag=subdag_good)
self.assertRaises(
AirflowException,
SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad1)
self.assertRaises(
AirflowException,
SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad2)
self.assertRaises(
AirflowException,
SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad3)

def test_subdag_pools(self):
"""
Subdags and subdag tasks can't both have a pool with 1 slot
"""
dag = DAG('parent', default_args=default_args)
subdag = DAG('parent.test', default_args=default_args)

session = airflow.settings.Session()
pool_1 = airflow.models.Pool(pool='test_pool_1', slots=1)
pool_10 = airflow.models.Pool(pool='test_pool_10', slots=10)
session.add(pool_1)
session.add(pool_10)
session.commit()

dummy_1 = DummyOperator(task_id='dummy', dag=subdag, pool='test_pool_1')

self.assertRaises(
AirflowException,
SubDagOperator,
task_id='test', dag=dag, subdag=subdag, pool='test_pool_1')

# recreate dag because failed subdagoperator was already added
dag = DAG('parent', default_args=default_args)
SubDagOperator(
task_id='test', dag=dag, subdag=subdag, pool='test_pool_10')

session.delete(pool_1)
session.delete(pool_10)
session.commit()


if __name__ == "__main__":
unittest.main()

0 comments on commit 78f5640

Please sign in to comment.