diff --git a/tests/conftest.py b/tests/conftest.py index 197a57380202e..c7685d4287751 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -428,6 +428,30 @@ def app(): @pytest.fixture def dag_maker(request): + """ + The dag_maker helps us to create DAG & DagModel automatically. + + You have to use the dag_maker as a context manager and it takes + the same argument as DAG:: + + with dag_maker(dag_id="mydag") as dag: + task1 = DummyOperator(task_id='mytask') + task2 = DummyOperator(task_id='mytask2') + + If the DagModel you want to use needs different parameters than the one + automatically created by the dag_maker, you have to update the DagModel as below:: + + dag_maker.dag_model.is_active = False + session.merge(dag_maker.dag_model) + session.commit() + + For any test you use the dag_maker, make sure to create a DagRun:: + + dag_maker.create_dagrun() + + The dag_maker.create_dagrun takes the same arguments as dag.create_dagrun + + """ from airflow.models import DAG, DagModel from airflow.utils import timezone from airflow.utils.session import provide_session @@ -473,7 +497,12 @@ def __call__(self, dag_id='test_dag', session=None, **kwargs): self.kwargs = kwargs self.session = session self.start_date = self.kwargs.get('start_date', None) + default_args = kwargs.get('default_args', None) + if default_args and not self.start_date: + if 'start_date' in default_args: + self.start_date = default_args.get('start_date') if not self.start_date: + if hasattr(request.module, 'DEFAULT_DATE'): self.start_date = getattr(request.module, 'DEFAULT_DATE') else: @@ -484,3 +513,56 @@ def __call__(self, dag_id='test_dag', session=None, **kwargs): return self return DagFactory() + + +@pytest.fixture +def create_dummy_dag(dag_maker): + """ + This fixture creates a `DAG` with a single `DummyOperator` task. + DagRun and DagModel is also created. + + Apart from the already existing arguments, any other argument in kwargs + is passed to the DAG and not to the DummyOperator task. + + If you have an argument that you want to pass to the DummyOperator that + is not here, please use `default_args` so that the DAG will pass it to the + Task:: + + dag, task = create_dummy_dag(default_args={'start_date':timezone.datetime(2016, 1, 1)}) + + You cannot be able to alter the created DagRun or DagModel, use `dag_maker` fixture instead. + """ + from airflow.operators.dummy import DummyOperator + from airflow.utils.types import DagRunType + + def create_dag( + dag_id='dag', + task_id='op1', + task_concurrency=16, + pool='default_pool', + executor_config={}, + trigger_rule='all_done', + on_success_callback=None, + on_execute_callback=None, + on_failure_callback=None, + on_retry_callback=None, + email=None, + **kwargs, + ): + with dag_maker(dag_id, **kwargs) as dag: + op = DummyOperator( + task_id=task_id, + task_concurrency=task_concurrency, + executor_config=executor_config, + on_success_callback=on_success_callback, + on_execute_callback=on_execute_callback, + on_failure_callback=on_failure_callback, + on_retry_callback=on_retry_callback, + email=email, + pool=pool, + trigger_rule=trigger_rule, + ) + dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + return dag, op + + return create_dag diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index fcfd0f628ebe6..d99a8cdf76ef9 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -19,7 +19,6 @@ import datetime import os -from datetime import timedelta from tempfile import NamedTemporaryFile from unittest import mock from unittest.mock import MagicMock, patch @@ -30,7 +29,7 @@ from airflow import settings from airflow.configuration import conf from airflow.dag_processing.processor import DagFileProcessor -from airflow.models import DAG, DagBag, DagModel, SlaMiss, TaskInstance, errors +from airflow.models import DagBag, SlaMiss, TaskInstance, errors from airflow.models.taskinstance import SimpleTaskInstance from airflow.operators.dummy import DummyOperator from airflow.utils import timezone @@ -97,34 +96,12 @@ def teardown_method(self) -> None: self.scheduler_job = None self.clean_db() - def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(hours=1), **kwargs): - dag = DAG( - dag_id='test_scheduler_reschedule', - start_date=start_date, - # Make sure it only creates a single DAG Run - end_date=end_date, - ) - dag.clear() - dag.is_subdag = False - with create_session() as session: - orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False) - session.merge(orm_dag) - session.commit() - return dag - - @classmethod - def setup_class(cls): - # Ensure the DAGs we are looking at from the DB are up-to-date - non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False) - non_serialized_dagbag.sync_to_db() - cls.dagbag = DagBag(read_dags_from_db=True) - def _process_file(self, file_path, session): dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag_file_processor.process_file(file_path, [], False, session) - def test_dag_file_processor_sla_miss_callback(self): + def test_dag_file_processor_sla_miss_callback(self, create_dummy_dag): """ Test that the dag file processor calls the sla miss callback """ @@ -135,14 +112,13 @@ def test_dag_file_processor_sla_miss_callback(self): # Create dag with a start of 1 day ago, but an sla of 0 # so we'll already have an sla_miss on the books. test_start_date = days_ago(1) - dag = DAG( + dag, task = create_dummy_dag( dag_id='test_sla_miss', + task_id='dummy', sla_miss_callback=sla_callback, default_args={'start_date': test_start_date, 'sla': datetime.timedelta()}, ) - task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') - session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) @@ -152,7 +128,7 @@ def test_dag_file_processor_sla_miss_callback(self): assert sla_callback.called - def test_dag_file_processor_sla_miss_callback_invalid_sla(self): + def test_dag_file_processor_sla_miss_callback_invalid_sla(self, create_dummy_dag): """ Test that the dag file processor does not call the sla miss callback when given an invalid sla @@ -165,14 +141,13 @@ def test_dag_file_processor_sla_miss_callback_invalid_sla(self): # so we'll already have an sla_miss on the books. # Pass anything besides a timedelta object to the sla argument. test_start_date = days_ago(1) - dag = DAG( + dag, task = create_dummy_dag( dag_id='test_sla_miss', + task_id='dummy', sla_miss_callback=sla_callback, default_args={'start_date': test_start_date, 'sla': None}, ) - task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') - session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) @@ -181,7 +156,7 @@ def test_dag_file_processor_sla_miss_callback_invalid_sla(self): dag_file_processor.manage_slas(dag=dag, session=session) sla_callback.assert_not_called() - def test_dag_file_processor_sla_miss_callback_sent_notification(self): + def test_dag_file_processor_sla_miss_callback_sent_notification(self, create_dummy_dag): """ Test that the dag file processor does not call the sla_miss_callback when a notification has already been sent @@ -194,14 +169,13 @@ def test_dag_file_processor_sla_miss_callback_sent_notification(self): # Create dag with a start of 2 days ago, but an sla of 1 day # ago so we'll already have an sla_miss on the books test_start_date = days_ago(2) - dag = DAG( + dag, task = create_dummy_dag( dag_id='test_sla_miss', + task_id='dummy', sla_miss_callback=sla_callback, default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, ) - task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') - # Create a TaskInstance for two days ago session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) @@ -222,7 +196,7 @@ def test_dag_file_processor_sla_miss_callback_sent_notification(self): sla_callback.assert_not_called() - def test_dag_file_processor_sla_miss_callback_exception(self): + def test_dag_file_processor_sla_miss_callback_exception(self, create_dummy_dag): """ Test that the dag file processor gracefully logs an exception if there is a problem calling the sla_miss_callback @@ -232,14 +206,13 @@ def test_dag_file_processor_sla_miss_callback_exception(self): sla_callback = MagicMock(side_effect=RuntimeError('Could not call function')) test_start_date = days_ago(2) - dag = DAG( + dag, task = create_dummy_dag( dag_id='test_sla_miss', + task_id='dummy', sla_miss_callback=sla_callback, - default_args={'start_date': test_start_date}, + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)}, ) - task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1)) - session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) # Create an SlaMiss where notification was sent, but email was not @@ -255,18 +228,18 @@ def test_dag_file_processor_sla_miss_callback_exception(self): ) @mock.patch('airflow.dag_processing.processor.send_email') - def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email): + def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks( + self, mock_send_email, create_dummy_dag + ): session = settings.Session() test_start_date = days_ago(2) - dag = DAG( - dag_id='test_sla_miss', - default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, - ) - email1 = 'test1@test.com' - task = DummyOperator( - task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1) + dag, task = create_dummy_dag( + dag_id='test_sla_miss', + task_id='sla_missed', + email=email1, + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)}, ) session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) @@ -288,7 +261,9 @@ def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock @mock.patch('airflow.dag_processing.processor.Stats.incr') @mock.patch("airflow.utils.email.send_email") - def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr): + def test_dag_file_processor_sla_miss_email_exception( + self, mock_send_email, mock_stats_incr, create_dummy_dag + ): """ Test that the dag file processor gracefully logs an exception if there is a problem sending an email @@ -299,14 +274,13 @@ def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock mock_send_email.side_effect = RuntimeError('Could not send an email') test_start_date = days_ago(2) - dag = DAG( + dag, task = create_dummy_dag( dag_id='test_sla_miss', - default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, - ) - - task = DummyOperator( - task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1) + task_id='dummy', + email='test@test.com', + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)}, ) + mock_stats_incr.reset_mock() session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) @@ -322,7 +296,7 @@ def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock ) mock_stats_incr.assert_called_once_with('sla_email_notification_failure') - def test_dag_file_processor_sla_miss_deleted_task(self): + def test_dag_file_processor_sla_miss_deleted_task(self, create_dummy_dag): """ Test that the dag file processor will not crash when trying to send sla miss notification for a deleted task @@ -330,13 +304,11 @@ def test_dag_file_processor_sla_miss_deleted_task(self): session = settings.Session() test_start_date = days_ago(2) - dag = DAG( + dag, task = create_dummy_dag( dag_id='test_sla_miss', - default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, - ) - - task = DummyOperator( - task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1) + task_id='dummy', + email='test@test.com', + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)}, ) session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index d60d1b7cd3d2a..93f61386adcf8 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -100,39 +100,6 @@ def success_handler(self, context): self.task_state_in_callback = temp_instance.state -@pytest.fixture -def get_dummy_dag(dag_maker): - def create_dag( - dag_id='dag', - task_id='op1', - task_concurrency=16, - pool='default_pool', - executor_config={}, - trigger_rule='all_done', - on_success_callback=None, - on_execute_callback=None, - on_failure_callback=None, - on_retry_callback=None, - **kwargs, - ): - with dag_maker(dag_id, **kwargs) as dag: - op = DummyOperator( - task_id=task_id, - task_concurrency=task_concurrency, - executor_config=executor_config, - on_success_callback=on_success_callback, - on_execute_callback=on_execute_callback, - on_failure_callback=on_failure_callback, - on_retry_callback=on_retry_callback, - pool=pool, - trigger_rule=trigger_rule, - ) - dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) - return dag, op - - return create_dag - - class TestTaskInstance: @staticmethod def clean_db(): @@ -265,13 +232,13 @@ def test_set_dag(self, dag_maker): assert op.dag is dag assert op in dag.tasks - def test_infer_dag(self, get_dummy_dag): + def test_infer_dag(self, create_dummy_dag): op1 = DummyOperator(task_id='test_op_1') op2 = DummyOperator(task_id='test_op_2') - dag, op3 = get_dummy_dag(task_id='test_op_3') + dag, op3 = create_dummy_dag(task_id='test_op_3') - _, op4 = get_dummy_dag('dag2', task_id='test_op_4') + _, op4 = create_dummy_dag('dag2', task_id='test_op_4') # double check dags assert [i.has_dag() for i in [op1, op2, op3, op4]] == [False, False, True, True] @@ -304,10 +271,10 @@ def test_bitshift_compose_operators(self, dag_maker): assert op2 in op3.downstream_list @patch.object(DAG, 'get_concurrency_reached') - def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, get_dummy_dag): + def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, create_dummy_dag): mock_concurrency_reached.return_value = True - _, task = get_dummy_dag( + _, task = create_dummy_dag( dag_id='test_requeue_over_dag_concurrency', task_id='test_requeue_over_dag_concurrency_op', max_active_runs=1, @@ -322,8 +289,8 @@ def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, get_dummy_ ti.run() assert ti.state == State.NONE - def test_requeue_over_task_concurrency(self, get_dummy_dag): - _, task = get_dummy_dag( + def test_requeue_over_task_concurrency(self, create_dummy_dag): + _, task = create_dummy_dag( dag_id='test_requeue_over_task_concurrency', task_id='test_requeue_over_task_concurrency_op', task_concurrency=0, @@ -339,8 +306,8 @@ def test_requeue_over_task_concurrency(self, get_dummy_dag): ti.run() assert ti.state == State.NONE - def test_requeue_over_pool_concurrency(self, get_dummy_dag): - _, task = get_dummy_dag( + def test_requeue_over_pool_concurrency(self, create_dummy_dag): + _, task = create_dummy_dag( dag_id='test_requeue_over_pool_concurrency', task_id='test_requeue_over_pool_concurrency_op', task_concurrency=0, @@ -391,13 +358,13 @@ def test_not_requeue_non_requeueable_task_instance(self, dag_maker): for (dep_patch, method_patch) in patch_dict.values(): dep_patch.stop() - def test_mark_non_runnable_task_as_success(self, get_dummy_dag): + def test_mark_non_runnable_task_as_success(self, create_dummy_dag): """ test that running task with mark_success param update task state as SUCCESS without running task despite it fails dependency checks. """ non_runnable_state = (set(State.task_states) - RUNNABLE_STATES - set(State.SUCCESS)).pop() - _, task = get_dummy_dag( + _, task = create_dummy_dag( dag_id='test_mark_non_runnable_task_as_success', task_id='test_mark_non_runnable_task_as_success_op', ) @@ -409,11 +376,11 @@ def test_mark_non_runnable_task_as_success(self, get_dummy_dag): ti.run(mark_success=True) assert ti.state == State.SUCCESS - def test_run_pooling_task(self, get_dummy_dag): + def test_run_pooling_task(self, create_dummy_dag): """ test that running a task in an existing pool update task state as SUCCESS. """ - _, task = get_dummy_dag( + _, task = create_dummy_dag( dag_id='test_run_pooling_task', task_id='test_run_pooling_task_op', pool='test_pool', @@ -444,11 +411,11 @@ def create_task_instance(): create_task_instance() @provide_session - def test_ti_updates_with_task(self, get_dummy_dag, session=None): + def test_ti_updates_with_task(self, create_dummy_dag, session=None): """ test that updating the executor_config propagates to the TaskInstance DB """ - dag, task = get_dummy_dag( + dag, task = create_dummy_dag( dag_id='test_run_pooling_task', task_id='test_run_pooling_task_op', executor_config={'foo': 'bar'}, @@ -472,13 +439,13 @@ def test_ti_updates_with_task(self, get_dummy_dag, session=None): assert {'bar': 'baz'} == tis[1].executor_config session.rollback() - def test_run_pooling_task_with_mark_success(self, get_dummy_dag): + def test_run_pooling_task_with_mark_success(self, create_dummy_dag): """ test that running task in an existing pool with mark_success param update task state as SUCCESS without running task despite it fails dependency checks. """ - _, task = get_dummy_dag( + _, task = create_dummy_dag( dag_id='test_run_pooling_task_with_mark_success', task_id='test_run_pooling_task_with_mark_success_op', ) @@ -944,9 +911,9 @@ def test_check_task_dependencies( flag_upstream_failed: bool, expect_state: State, expect_completed: bool, - get_dummy_dag, + create_dummy_dag, ): - dag, downstream = get_dummy_dag('test-dag', task_id='downstream', trigger_rule=trigger_rule) + dag, downstream = create_dummy_dag('test-dag', task_id='downstream', trigger_rule=trigger_rule) for i in range(5): task = DummyOperator(task_id=f'runme_{i}', dag=dag) task.set_downstream(downstream) @@ -967,8 +934,8 @@ def test_check_task_dependencies( assert completed == expect_completed assert ti.state == expect_state - def test_respects_prev_dagrun_dep(self, get_dummy_dag): - _, task = get_dummy_dag(dag_id='test_dag') + def test_respects_prev_dagrun_dep(self, create_dummy_dag): + _, task = create_dummy_dag(dag_id='test_dag') ti = TI(task, DEFAULT_DATE) failing_status = [TIDepStatus('test fail status name', False, 'test fail reason')] passing_status = [TIDepStatus('test pass status name', True, 'test passing reason')] @@ -991,8 +958,8 @@ def test_respects_prev_dagrun_dep(self, get_dummy_dag): (State.NONE, False), ], ) - def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_done, get_dummy_dag): - dag, task = get_dummy_dag() + def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_done, create_dummy_dag): + dag, task = create_dummy_dag() downstream_task = DummyOperator(task_id='downstream_task', dag=dag) task >> downstream_task @@ -1002,11 +969,11 @@ def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_ downstream_ti.set_state(downstream_ti_state) assert ti.are_dependents_done() == expected_are_dependents_done - def test_xcom_pull(self, get_dummy_dag): + def test_xcom_pull(self, create_dummy_dag): """ Test xcom_pull, using different filtering methods. """ - dag, task1 = get_dummy_dag( + dag, task1 = create_dummy_dag( dag_id='test_xcom', task_id='test_xcom_1', schedule_interval='@monthly', @@ -1040,14 +1007,14 @@ def test_xcom_pull(self, get_dummy_dag): result = ti1.xcom_pull(task_ids=['test_xcom_1', 'test_xcom_2'], key='foo') assert result == ['bar', 'baz'] - def test_xcom_pull_after_success(self, get_dummy_dag): + def test_xcom_pull_after_success(self, create_dummy_dag): """ tests xcom set/clear relative to a task in a 'success' rerun scenario """ key = 'xcom_key' value = 'xcom_value' - _, task = get_dummy_dag( + _, task = create_dummy_dag( dag_id='test_xcom', schedule_interval='@monthly', task_id='test_xcom', @@ -1072,7 +1039,7 @@ def test_xcom_pull_after_success(self, get_dummy_dag): ti.run(ignore_all_deps=True) assert ti.xcom_pull(task_ids='test_xcom', key=key) is None - def test_xcom_pull_different_execution_date(self, get_dummy_dag): + def test_xcom_pull_different_execution_date(self, create_dummy_dag): """ tests xcom fetch behavior with different execution dates, using both xcom_pull with "include_prior_dates" and without @@ -1080,7 +1047,7 @@ def test_xcom_pull_different_execution_date(self, get_dummy_dag): key = 'xcom_key' value = 'xcom_value' - dag, task = get_dummy_dag( + dag, task = create_dummy_dag( dag_id='test_xcom', schedule_interval='@monthly', task_id='test_xcom', @@ -1146,8 +1113,8 @@ def post_execute(self, context, result=None): with pytest.raises(TestError): ti.run() - def test_check_and_change_state_before_execution(self, get_dummy_dag): - _, task = get_dummy_dag(dag_id='test_check_and_change_state_before_execution') + def test_check_and_change_state_before_execution(self, create_dummy_dag): + _, task = create_dummy_dag(dag_id='test_check_and_change_state_before_execution') ti = TI(task=task, execution_date=DEFAULT_DATE) assert ti._try_number == 0 assert ti.check_and_change_state_before_execution() @@ -1155,18 +1122,18 @@ def test_check_and_change_state_before_execution(self, get_dummy_dag): assert ti.state == State.RUNNING assert ti._try_number == 1 - def test_check_and_change_state_before_execution_dep_not_met(self, get_dummy_dag): - dag, task = get_dummy_dag(dag_id='test_check_and_change_state_before_execution') + def test_check_and_change_state_before_execution_dep_not_met(self, create_dummy_dag): + dag, task = create_dummy_dag(dag_id='test_check_and_change_state_before_execution') task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE) task >> task2 ti = TI(task=task2, execution_date=timezone.utcnow()) assert not ti.check_and_change_state_before_execution() - def test_try_number(self, get_dummy_dag): + def test_try_number(self, create_dummy_dag): """ Test the try_number accessor behaves in various running states """ - _, task = get_dummy_dag(dag_id='test_check_and_change_state_before_execution') + _, task = create_dummy_dag(dag_id='test_check_and_change_state_before_execution') ti = TI(task=task, execution_date=timezone.utcnow()) assert 1 == ti.try_number ti.try_number = 2 @@ -1175,11 +1142,11 @@ def test_try_number(self, get_dummy_dag): ti.state = State.SUCCESS assert 3 == ti.try_number - def test_get_num_running_task_instances(self, get_dummy_dag): + def test_get_num_running_task_instances(self, create_dummy_dag): session = settings.Session() - _, task = get_dummy_dag(dag_id='test_get_num_running_task_instances', task_id='task1') - _, task2 = get_dummy_dag(dag_id='test_get_num_running_task_instances_dummy', task_id='task2') + _, task = create_dummy_dag(dag_id='test_get_num_running_task_instances', task_id='task1') + _, task2 = create_dummy_dag(dag_id='test_get_num_running_task_instances_dummy', task_id='task2') ti1 = TI(task=task, execution_date=DEFAULT_DATE) ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) ti3 = TI(task=task2, execution_date=DEFAULT_DATE) @@ -1207,8 +1174,8 @@ def test_get_num_running_task_instances(self, get_dummy_dag): # self.assertEqual(d['task_id'][0], 'op') # self.assertEqual(pendulum.parse(d['execution_date'][0]), now) - def test_log_url(self, get_dummy_dag): - _, task = get_dummy_dag('dag', task_id='op') + def test_log_url(self, create_dummy_dag): + _, task = create_dummy_dag('dag', task_id='op') ti = TI(task=task, execution_date=datetime.datetime(2018, 1, 1)) expected_url = ( @@ -1219,9 +1186,9 @@ def test_log_url(self, get_dummy_dag): ) assert ti.log_url == expected_url - def test_mark_success_url(self, get_dummy_dag): + def test_mark_success_url(self, create_dummy_dag): now = pendulum.now('Europe/Brussels') - _, task = get_dummy_dag('dag', task_id='op') + _, task = create_dummy_dag('dag', task_id='op') ti = TI(task=task, execution_date=now) query = urllib.parse.parse_qs( urllib.parse.urlparse(ti.mark_success_url).query, keep_blank_values=True, strict_parsing=True @@ -1324,9 +1291,9 @@ def test_set_duration_empty_dates(self): ti.set_duration() assert ti.duration is None - def test_success_callback_no_race_condition(self, get_dummy_dag): + def test_success_callback_no_race_condition(self, create_dummy_dag): callback_wrapper = CallbackWrapper() - _, task = get_dummy_dag( + _, task = create_dummy_dag( 'test_success_callback_no_race_condition', on_success_callback=callback_wrapper.success_handler, end_date=DEFAULT_DATE + datetime.timedelta(days=10), @@ -1472,8 +1439,8 @@ def test_get_previous_start_date_none(self, dag_maker): assert ti_2.get_previous_start_date() == ti_1.start_date assert ti_1.start_date is None - def test_pendulum_template_dates(self, get_dummy_dag): - dag, task = get_dummy_dag( + def test_pendulum_template_dates(self, create_dummy_dag): + dag, task = create_dummy_dag( dag_id='test_pendulum_template_dates', task_id='test_pendulum_template_dates_task', schedule_interval='0 12 * * *', @@ -1500,7 +1467,7 @@ def test_pendulum_template_dates(self, get_dummy_dag): ('{{ conn.a_connection.extra_dejson.extra__asana__workspace }}', 'extra1'), ], ) - def test_template_with_connection(self, content, expected_output, get_dummy_dag): + def test_template_with_connection(self, content, expected_output, create_dummy_dag): """ Test the availability of variables in templates """ @@ -1522,7 +1489,7 @@ def test_template_with_connection(self, content, expected_output, get_dummy_dag) session, ) - _, task = get_dummy_dag() + _, task = create_dummy_dag() ti = TI(task=task, execution_date=DEFAULT_DATE) context = ti.get_template_context() @@ -1538,24 +1505,24 @@ def test_template_with_connection(self, content, expected_output, get_dummy_dag) ('{{ var.value.get("missing_variable", "fallback") }}', 'fallback'), ], ) - def test_template_with_variable(self, content, expected_output, get_dummy_dag): + def test_template_with_variable(self, content, expected_output, create_dummy_dag): """ Test the availability of variables in templates """ Variable.set('a_variable', 'a test value') - _, task = get_dummy_dag() + _, task = create_dummy_dag() ti = TI(task=task, execution_date=DEFAULT_DATE) context = ti.get_template_context() result = task.render_template(content, context) assert result == expected_output - def test_template_with_variable_missing(self, get_dummy_dag): + def test_template_with_variable_missing(self, create_dummy_dag): """ Test the availability of variables in templates """ - _, task = get_dummy_dag() + _, task = create_dummy_dag() ti = TI(task=task, execution_date=DEFAULT_DATE) context = ti.get_template_context() @@ -1572,28 +1539,28 @@ def test_template_with_variable_missing(self, get_dummy_dag): ('{{ var.json.get("missing_variable", {"a": {"test": "fallback"}})["a"]["test"] }}', 'fallback'), ], ) - def test_template_with_json_variable(self, content, expected_output, get_dummy_dag): + def test_template_with_json_variable(self, content, expected_output, create_dummy_dag): """ Test the availability of variables in templates """ Variable.set('a_variable', {'a': {'test': 'value'}}, serialize_json=True) - _, task = get_dummy_dag() + _, task = create_dummy_dag() ti = TI(task=task, execution_date=DEFAULT_DATE) context = ti.get_template_context() result = task.render_template(content, context) assert result == expected_output - def test_template_with_json_variable_missing(self, get_dummy_dag): - _, task = get_dummy_dag() + def test_template_with_json_variable_missing(self, create_dummy_dag): + _, task = create_dummy_dag() ti = TI(task=task, execution_date=DEFAULT_DATE) context = ti.get_template_context() with pytest.raises(KeyError): task.render_template('{{ var.json.get("missing_variable") }}', context) - def test_execute_callback(self, get_dummy_dag): + def test_execute_callback(self, create_dummy_dag): called = False def on_execute_callable(context): @@ -1601,7 +1568,7 @@ def on_execute_callable(context): called = True assert context['dag_run'].dag_id == 'test_dagrun_execute_callback' - _, task = get_dummy_dag( + _, task = create_dummy_dag( 'test_execute_callback', on_execute_callback=on_execute_callable, end_date=DEFAULT_DATE + datetime.timedelta(days=10), @@ -1656,12 +1623,12 @@ def on_finish_callable(context): assert not completed ti.log.exception.assert_called_once_with(expected_message) - def test_handle_failure(self, get_dummy_dag): + def test_handle_failure(self, create_dummy_dag): start_date = timezone.datetime(2016, 6, 1) mock_on_failure_1 = mock.MagicMock() mock_on_retry_1 = mock.MagicMock() - dag, task1 = get_dummy_dag( + dag, task1 = create_dummy_dag( dag_id="test_handle_failure", schedule_interval=None, start_date=start_date, @@ -1777,8 +1744,8 @@ def test_echo_env_variables(self, dag_maker): assert ti.state == State.SUCCESS @patch.object(Stats, 'incr') - def test_task_stats(self, stats_mock, get_dummy_dag): - dag, op = get_dummy_dag( + def test_task_stats(self, stats_mock, create_dummy_dag): + dag, op = create_dummy_dag( 'test_task_start_end_stats', end_date=DEFAULT_DATE + datetime.timedelta(days=10), ) @@ -1942,8 +1909,8 @@ def test_get_rendered_k8s_spec(self, rtif_get_k8s_pod_yaml, dag_maker): render_k8s_pod_yaml.assert_called_once() - def test_set_state_up_for_retry(self, get_dummy_dag): - dag, op1 = get_dummy_dag('dag') + def test_set_state_up_for_retry(self, create_dummy_dag): + dag, op1 = create_dummy_dag('dag') ti = TI(task=op1, execution_date=timezone.utcnow(), state=State.RUNNING) start_date = timezone.utcnow() @@ -2072,8 +2039,8 @@ def teardown_method(self) -> None: (7, True), ], ) - def test_execute_queries_count(self, expected_query_count, mark_success, get_dummy_dag): - _, task = get_dummy_dag() + def test_execute_queries_count(self, expected_query_count, mark_success, create_dummy_dag): + _, task = create_dummy_dag() with create_session() as session: ti = TI(task=task, execution_date=datetime.datetime.now()) @@ -2091,8 +2058,8 @@ def test_execute_queries_count(self, expected_query_count, mark_success, get_dum with assert_queries_count(expected_query_count_based_on_db): ti._run_raw_task(mark_success=mark_success) - def test_execute_queries_count_store_serialized(self, get_dummy_dag): - _, task = get_dummy_dag() + def test_execute_queries_count_store_serialized(self, create_dummy_dag): + _, task = create_dummy_dag() with create_session() as session: ti = TI(task=task, execution_date=datetime.datetime.now()) ti.state = State.RUNNING @@ -2105,9 +2072,9 @@ def test_execute_queries_count_store_serialized(self, get_dummy_dag): with assert_queries_count(expected_query_count_based_on_db): ti._run_raw_task() - def test_operator_field_with_serialization(self, get_dummy_dag): + def test_operator_field_with_serialization(self, create_dummy_dag): - _, task = get_dummy_dag() + _, task = create_dummy_dag() assert task.task_type == 'DummyOperator' # Verify that ti.operator field renders correctly "without" Serialization