Skip to content

Commit

Permalink
Use dag_maker fixture in test_processor.py (apache#17506)
Browse files Browse the repository at this point in the history
This change applies the dag_maker fixture in test_processor.py

fixup! Use dag_maker fixture in test_processor.py

fixup! fixup! Use dag_maker fixture in test_processor.py
  • Loading branch information
ephraimbuddy authored Aug 11, 2021
1 parent 3773224 commit a1834ce
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 166 deletions.
82 changes: 82 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,30 @@ def app():

@pytest.fixture
def dag_maker(request):
"""
The dag_maker helps us to create DAG & DagModel automatically.
You have to use the dag_maker as a context manager and it takes
the same argument as DAG::
with dag_maker(dag_id="mydag") as dag:
task1 = DummyOperator(task_id='mytask')
task2 = DummyOperator(task_id='mytask2')
If the DagModel you want to use needs different parameters than the one
automatically created by the dag_maker, you have to update the DagModel as below::
dag_maker.dag_model.is_active = False
session.merge(dag_maker.dag_model)
session.commit()
For any test you use the dag_maker, make sure to create a DagRun::
dag_maker.create_dagrun()
The dag_maker.create_dagrun takes the same arguments as dag.create_dagrun
"""
from airflow.models import DAG, DagModel
from airflow.utils import timezone
from airflow.utils.session import provide_session
Expand Down Expand Up @@ -473,7 +497,12 @@ def __call__(self, dag_id='test_dag', session=None, **kwargs):
self.kwargs = kwargs
self.session = session
self.start_date = self.kwargs.get('start_date', None)
default_args = kwargs.get('default_args', None)
if default_args and not self.start_date:
if 'start_date' in default_args:
self.start_date = default_args.get('start_date')
if not self.start_date:

if hasattr(request.module, 'DEFAULT_DATE'):
self.start_date = getattr(request.module, 'DEFAULT_DATE')
else:
Expand All @@ -484,3 +513,56 @@ def __call__(self, dag_id='test_dag', session=None, **kwargs):
return self

return DagFactory()


@pytest.fixture
def create_dummy_dag(dag_maker):
    """
    This fixture creates a `DAG` with a single `DummyOperator` task.
    A DagRun and DagModel are also created.

    Apart from the already existing arguments, any other argument in kwargs
    is passed to the DAG and not to the DummyOperator task.

    If you have an argument that you want to pass to the DummyOperator that
    is not here, please use `default_args` so that the DAG will pass it to the
    Task::

        dag, task = create_dummy_dag(default_args={'start_date': timezone.datetime(2016, 1, 1)})

    You cannot alter the created DagRun or DagModel; use the `dag_maker` fixture instead.

    :return: a factory callable that builds the DAG and returns ``(dag, task)``.
    """
    from airflow.operators.dummy import DummyOperator
    from airflow.utils.types import DagRunType

    def create_dag(
        dag_id='dag',
        task_id='op1',
        task_concurrency=16,
        pool='default_pool',
        executor_config=None,
        trigger_rule='all_done',
        on_success_callback=None,
        on_execute_callback=None,
        on_failure_callback=None,
        on_retry_callback=None,
        email=None,
        **kwargs,
    ):
        # Any extra kwargs are forwarded to the DAG constructor, not the task.
        with dag_maker(dag_id, **kwargs) as dag:
            op = DummyOperator(
                task_id=task_id,
                task_concurrency=task_concurrency,
                # `None` sentinel instead of a mutable `{}` default so a fresh
                # dict is used per call (avoids the shared-mutable-default pitfall).
                executor_config=executor_config or {},
                on_success_callback=on_success_callback,
                on_execute_callback=on_execute_callback,
                on_failure_callback=on_failure_callback,
                on_retry_callback=on_retry_callback,
                email=email,
                pool=pool,
                trigger_rule=trigger_rule,
            )
        # Create a scheduled DagRun so tests can use the DAG immediately.
        dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
        return dag, op

    return create_dag
98 changes: 35 additions & 63 deletions tests/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import datetime
import os
from datetime import timedelta
from tempfile import NamedTemporaryFile
from unittest import mock
from unittest.mock import MagicMock, patch
Expand All @@ -30,7 +29,7 @@
from airflow import settings
from airflow.configuration import conf
from airflow.dag_processing.processor import DagFileProcessor
from airflow.models import DAG, DagBag, DagModel, SlaMiss, TaskInstance, errors
from airflow.models import DagBag, SlaMiss, TaskInstance, errors
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.operators.dummy import DummyOperator
from airflow.utils import timezone
Expand Down Expand Up @@ -97,34 +96,12 @@ def teardown_method(self) -> None:
self.scheduler_job = None
self.clean_db()

def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(hours=1), **kwargs):
dag = DAG(
dag_id='test_scheduler_reschedule',
start_date=start_date,
# Make sure it only creates a single DAG Run
end_date=end_date,
)
dag.clear()
dag.is_subdag = False
with create_session() as session:
orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False)
session.merge(orm_dag)
session.commit()
return dag

@classmethod
def setup_class(cls):
# Ensure the DAGs we are looking at from the DB are up-to-date
non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False)
non_serialized_dagbag.sync_to_db()
cls.dagbag = DagBag(read_dags_from_db=True)

def _process_file(self, file_path, session):
dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())

dag_file_processor.process_file(file_path, [], False, session)

def test_dag_file_processor_sla_miss_callback(self):
def test_dag_file_processor_sla_miss_callback(self, create_dummy_dag):
"""
Test that the dag file processor calls the sla miss callback
"""
Expand All @@ -135,14 +112,13 @@ def test_dag_file_processor_sla_miss_callback(self):
# Create dag with a start of 1 day ago, but an sla of 0
# so we'll already have an sla_miss on the books.
test_start_date = days_ago(1)
dag = DAG(
dag, task = create_dummy_dag(
dag_id='test_sla_miss',
task_id='dummy',
sla_miss_callback=sla_callback,
default_args={'start_date': test_start_date, 'sla': datetime.timedelta()},
)

task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')

session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))

session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
Expand All @@ -152,7 +128,7 @@ def test_dag_file_processor_sla_miss_callback(self):

assert sla_callback.called

def test_dag_file_processor_sla_miss_callback_invalid_sla(self):
def test_dag_file_processor_sla_miss_callback_invalid_sla(self, create_dummy_dag):
"""
Test that the dag file processor does not call the sla miss callback when
given an invalid sla
Expand All @@ -165,14 +141,13 @@ def test_dag_file_processor_sla_miss_callback_invalid_sla(self):
# so we'll already have an sla_miss on the books.
# Pass anything besides a timedelta object to the sla argument.
test_start_date = days_ago(1)
dag = DAG(
dag, task = create_dummy_dag(
dag_id='test_sla_miss',
task_id='dummy',
sla_miss_callback=sla_callback,
default_args={'start_date': test_start_date, 'sla': None},
)

task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')

session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))

session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
Expand All @@ -181,7 +156,7 @@ def test_dag_file_processor_sla_miss_callback_invalid_sla(self):
dag_file_processor.manage_slas(dag=dag, session=session)
sla_callback.assert_not_called()

def test_dag_file_processor_sla_miss_callback_sent_notification(self):
def test_dag_file_processor_sla_miss_callback_sent_notification(self, create_dummy_dag):
"""
Test that the dag file processor does not call the sla_miss_callback when a
notification has already been sent
Expand All @@ -194,14 +169,13 @@ def test_dag_file_processor_sla_miss_callback_sent_notification(self):
# Create dag with a start of 2 days ago, but an sla of 1 day
# ago so we'll already have an sla_miss on the books
test_start_date = days_ago(2)
dag = DAG(
dag, task = create_dummy_dag(
dag_id='test_sla_miss',
task_id='dummy',
sla_miss_callback=sla_callback,
default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
)

task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')

# Create a TaskInstance for two days ago
session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))

Expand All @@ -222,7 +196,7 @@ def test_dag_file_processor_sla_miss_callback_sent_notification(self):

sla_callback.assert_not_called()

def test_dag_file_processor_sla_miss_callback_exception(self):
def test_dag_file_processor_sla_miss_callback_exception(self, create_dummy_dag):
"""
Test that the dag file processor gracefully logs an exception if there is a problem
calling the sla_miss_callback
Expand All @@ -232,14 +206,13 @@ def test_dag_file_processor_sla_miss_callback_exception(self):
sla_callback = MagicMock(side_effect=RuntimeError('Could not call function'))

test_start_date = days_ago(2)
dag = DAG(
dag, task = create_dummy_dag(
dag_id='test_sla_miss',
task_id='dummy',
sla_miss_callback=sla_callback,
default_args={'start_date': test_start_date},
default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)},
)

task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1))

session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))

# Create an SlaMiss where notification was sent, but email was not
Expand All @@ -255,18 +228,18 @@ def test_dag_file_processor_sla_miss_callback_exception(self):
)

@mock.patch('airflow.dag_processing.processor.send_email')
def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email):
def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(
self, mock_send_email, create_dummy_dag
):
session = settings.Session()

test_start_date = days_ago(2)
dag = DAG(
dag_id='test_sla_miss',
default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
)

email1 = '[email protected]'
task = DummyOperator(
task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1)
dag, task = create_dummy_dag(
dag_id='test_sla_miss',
task_id='sla_missed',
email=email1,
default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)},
)

session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
Expand All @@ -288,7 +261,9 @@ def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock

@mock.patch('airflow.dag_processing.processor.Stats.incr')
@mock.patch("airflow.utils.email.send_email")
def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr):
def test_dag_file_processor_sla_miss_email_exception(
self, mock_send_email, mock_stats_incr, create_dummy_dag
):
"""
Test that the dag file processor gracefully logs an exception if there is a problem
sending an email
Expand All @@ -299,14 +274,13 @@ def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock
mock_send_email.side_effect = RuntimeError('Could not send an email')

test_start_date = days_ago(2)
dag = DAG(
dag, task = create_dummy_dag(
dag_id='test_sla_miss',
default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
)

task = DummyOperator(
task_id='dummy', dag=dag, owner='airflow', email='[email protected]', sla=datetime.timedelta(hours=1)
task_id='dummy',
email='[email protected]',
default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)},
)
mock_stats_incr.reset_mock()

session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))

Expand All @@ -322,21 +296,19 @@ def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock
)
mock_stats_incr.assert_called_once_with('sla_email_notification_failure')

def test_dag_file_processor_sla_miss_deleted_task(self):
def test_dag_file_processor_sla_miss_deleted_task(self, create_dummy_dag):
"""
Test that the dag file processor will not crash when trying to send
sla miss notification for a deleted task
"""
session = settings.Session()

test_start_date = days_ago(2)
dag = DAG(
dag, task = create_dummy_dag(
dag_id='test_sla_miss',
default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
)

task = DummyOperator(
task_id='dummy', dag=dag, owner='airflow', email='[email protected]', sla=datetime.timedelta(hours=1)
task_id='dummy',
email='[email protected]',
default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)},
)

session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
Expand Down
Loading

0 comments on commit a1834ce

Please sign in to comment.