DebugExecutor use ti.run() instead of ti._run_raw_task (apache#24357)
The DebugExecutor previously executed tasks by calling the "private"
ti._run_raw_task(...) method instead of ti.run(...). Only the latter
contains the logic that increases a task instance's try_number when it
runs, so tasks executed with the DebugExecutor never had their
try_number increased. For rescheduled tasks this led to off-by-one
errors, because the logic that reduces the try_number for a reschedule
still ran while the matching increase did not.
o-nikolas authored Jun 13, 2022
1 parent 94257f4 commit da7b22b
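
In rough terms (a simplified sketch of the behaviour described above, an assumption rather than a copy of airflow.models.taskinstance), the bookkeeping difference looks like this:

# Simplified sketch only; the real logic lives in airflow.models.taskinstance.
class TaskInstanceSketch:
    def __init__(self):
        self._try_number = 0
        self.state = None

    def run(self, **kwargs):
        # ti.run() bumps the try number before delegating to the raw task.
        self._try_number += 1
        self.state = "running"
        self._run_raw_task(**kwargs)

    def _run_raw_task(self, **kwargs):
        # Executes the task but does no try_number bookkeeping of its own,
        # which is why calling it directly left the counter untouched.
        self.state = "success"

    def _handle_reschedule(self):
        # A reschedule is still the same try, so the counter is lowered again.
        # Without the matching increase in run(), this decrement produced the
        # off-by-one described above.
        self._try_number -= 1
        self.state = "up_for_reschedule"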
Showing 4 changed files with 73 additions and 24 deletions.
2 changes: 1 addition & 1 deletion airflow/executors/debug_executor.py
@@ -75,7 +75,7 @@ def _run_task(self, ti: TaskInstance) -> bool:
key = ti.key
try:
params = self.tasks_params.pop(ti.key, {})
ti._run_raw_task(job_id=ti.job_id, **params)
ti.run(job_id=ti.job_id, **params)
self.change_state(key, State.SUCCESS)
return True
except Exception as e:
33 changes: 33 additions & 0 deletions tests/dags/test_sensor.py
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.sensors.date_time import DateTimeSensor
from airflow.utils import timezone

with DAG(
dag_id='test_sensor', start_date=datetime.datetime(2022, 1, 1), catchup=False, schedule_interval='@once'
) as dag:

@task
def get_date():
return str(timezone.utcnow() + datetime.timedelta(seconds=3))

DateTimeSensor(task_id='dts', target_time=str(get_date()), poke_interval=1, mode='reschedule')
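
The sensor above targets a time only a few seconds in the future and runs in reschedule mode, so a backfill of this DAG goes through at least one UP_FOR_RESCHEDULE round trip and exercises the try_number bookkeeping. A backfill with the DebugExecutor, roughly as the new test below drives it (a sketch that assumes the test_sensor DAG has already been parsed into `dag`, e.g. via a DagBag), could look like:

from airflow.executors.debug_executor import DebugExecutor
from airflow.jobs.backfill_job import BackfillJob
from airflow.utils import timezone

# `dag` is assumed to be the test_sensor DAG defined above.
when = timezone.datetime(2022, 1, 1)
job = BackfillJob(
    dag=dag,
    start_date=when,
    end_date=when,
    donot_pickle=True,
    executor=DebugExecutor(),
)
job.run()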
4 changes: 2 additions & 2 deletions tests/executors/test_debug_executor.py
@@ -49,7 +49,7 @@ def test_run_task(self, task_instance_mock):
succeeded = executor._run_task(task_instance_mock)

assert succeeded
task_instance_mock._run_raw_task.assert_called_once_with(job_id=job_id)
task_instance_mock.run.assert_called_once_with(job_id=job_id)

def test_queue_task_instance(self):
key = "ti_key"
@@ -100,7 +100,7 @@ def test_fail_fast(self, change_state_mock):
ti1 = MagicMock(key="t1")
ti2 = MagicMock(key="t2")

ti1._run_raw_task.side_effect = Exception
ti1.run.side_effect = Exception

executor.tasks_to_run = [ti1, ti2]

58 changes: 37 additions & 21 deletions tests/jobs/test_backfill_job.py
@@ -1313,10 +1313,18 @@ def test_update_counters(self, dag_maker, session):

ti_status = BackfillJob._DagRunTaskStatus()

# test for success
ti.set_state(State.SUCCESS, session)
ti_status.running[ti.key] = ti
job._update_counters(ti_status=ti_status, session=session)
# Test for success
# The in-memory task key in ti_status.running contains a try_number
# that is always one behind the DB. The _update_counters method however uses
# a reduced_key to handle this. To test this, we mark the task as running in-memory
# and then increase the try number as it would be before the raw task is executed.
# When updating the counters the reduced_key will be used which will match what's
# in the in-memory ti_status.running map. This is the same for skipped, failed
# and retry states.
ti_status.running[ti.key] = ti # Task is queued and marked as running
ti._try_number += 1 # Try number is increased during ti.run()
ti.set_state(State.SUCCESS, session) # Task finishes with success state
job._update_counters(ti_status=ti_status, session=session) # Update counters
assert len(ti_status.running) == 0
assert len(ti_status.succeeded) == 1
assert len(ti_status.skipped) == 0
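
The reduced key mentioned in the new comments can be pictured roughly as follows (a standalone sketch of the idea, simplified from TaskInstanceKey rather than copied from it): the key recorded in ti_status.running is taken before ti.run() increases the try number, so when the finished task reports back, _update_counters matches it via a key whose try_number is lowered by one.

# Sketch only; an assumption about the shape of the reduced key, simplified
# from airflow.models.taskinstance.TaskInstanceKey.
from typing import NamedTuple

class KeySketch(NamedTuple):
    dag_id: str
    task_id: str
    run_id: str
    try_number: int = 1

    @property
    def reduced(self) -> "KeySketch":
        # Never drops below the first try.
        return self._replace(try_number=max(1, self.try_number - 1))

# Illustrative values only: the key stored before the run matches the
# reduced form of the key seen after the try number was increased.
running_key = KeySketch("test_sensor", "dts", "backfill_run", try_number=1)
finished_key = KeySketch("test_sensor", "dts", "backfill_run", try_number=2)
assert finished_key.reduced == running_key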
@@ -1325,9 +1333,10 @@ def test_update_counters(self, dag_maker, session):

ti_status.succeeded.clear()

# test for skipped
ti.set_state(State.SKIPPED, session)
# Test for skipped
ti_status.running[ti.key] = ti
ti._try_number += 1
ti.set_state(State.SKIPPED, session)
job._update_counters(ti_status=ti_status, session=session)
assert len(ti_status.running) == 0
assert len(ti_status.succeeded) == 0
@@ -1337,9 +1346,10 @@ def test_update_counters(self, dag_maker, session):

ti_status.skipped.clear()

# test for failed
ti.set_state(State.FAILED, session)
# Test for failed
ti_status.running[ti.key] = ti
ti._try_number += 1
ti.set_state(State.FAILED, session)
job._update_counters(ti_status=ti_status, session=session)
assert len(ti_status.running) == 0
assert len(ti_status.succeeded) == 0
@@ -1349,9 +1359,10 @@ def test_update_counters(self, dag_maker, session):

ti_status.failed.clear()

# test for retry
ti.set_state(State.UP_FOR_RETRY, session)
# Test for retry
ti_status.running[ti.key] = ti
ti._try_number += 1
ti.set_state(State.UP_FOR_RETRY, session)
job._update_counters(ti_status=ti_status, session=session)
assert len(ti_status.running) == 0
assert len(ti_status.succeeded) == 0
@@ -1361,13 +1372,18 @@ def test_update_counters(self, dag_maker, session):

ti_status.to_run.clear()

# test for reschedule
# For rescheduled state, tests that reduced_key is not
# used by upping try_number.
ti._try_number = 2
ti.set_state(State.UP_FOR_RESCHEDULE, session)
assert ti.try_number == 3 # see ti.try_number property in taskinstance module
ti_status.running[ti.key] = ti
# Test for reschedule
# Logic in taskinstance reduces the try number for a task that's been
# rescheduled (which makes sense because it's the _same_ try, but it's
# just being rescheduled to a later time). This now makes the in-memory
# and DB representation of the task try_number the _same_, which is unlike
# the above cases. But this is okay because the reduced_key is NOT used for
# the rescheduled case in _update_counters, for this exact reason.
ti_status.running[ti.key] = ti # Task queued and marked as running
# Note: Both the increase and decrease are kept here for context
ti._try_number += 1 # Try number is increased during ti.run()
ti._try_number -= 1 # Task is being rescheduled, decrement try_number
ti.set_state(State.UP_FOR_RESCHEDULE, session) # Task finishes with reschedule state
job._update_counters(ti_status=ti_status, session=session)
assert len(ti_status.running) == 0
assert len(ti_status.succeeded) == 0
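
The in-memory key being one behind the DB in the success, skipped, failed and retry cases, but back in sync after a reschedule, follows from how the public try_number is derived from the private counter. A standalone sketch (an assumption about the property the removed comment "see ti.try_number property in taskinstance module" referred to, not the actual source):

# Sketch only; simplified model of the try_number property.
class TryNumberSketch:
    def __init__(self, private_try_number: int, running: bool):
        self._try_number = private_try_number
        self.running = running

    @property
    def try_number(self) -> int:
        # While running, the private counter is the current try; once the
        # task has finished or been rescheduled, the property points at the
        # try that would run next.
        if self.running:
            return self._try_number
        return self._try_number + 1

# Mirrors the assertion removed above: with the private counter at 2 and the
# task no longer running, the public try_number reads 3.
assert TryNumberSketch(2, running=False).try_number == 3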
@@ -1584,10 +1600,10 @@ def test_backfill_has_job_id(self):

@pytest.mark.long_running
@pytest.mark.parametrize("executor_name", ["SequentialExecutor", "DebugExecutor"])
@pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow"])
def test_mapped_dag(self, dag_id, executor_name, session):
@pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow", "test_sensor"])
def test_backfilling_dags(self, dag_id, executor_name, session):
"""
End-to-end test of a simple mapped dag.
End-to-end test for backfilling dags with various executors.
We test with multiple executors as they have different "execution environments" -- for instance
DebugExecutor runs a lot more in the same process than other Executors.
@@ -1599,7 +1615,7 @@ def test_mapped_dag(self, dag_id, executor_name, session):
self.dagbag.process_file(str(TEST_DAGS_FOLDER / f'{dag_id}.py'))
dag = self.dagbag.get_dag(dag_id)

when = datetime.datetime(2022, 1, 1)
when = timezone.datetime(2022, 1, 1)

job = BackfillJob(
dag=dag,
