Reuse _run_task_session in mapped render_template_fields (apache#33309)
The `render_template_fields` method of the mapped operator needs a
database session object to render mapped fields, but it cannot
receive the session via the @provide_session decorator, because the
method is overridden in derived classes and we cannot change its
signature without impacting those classes.

So far this was done by creating a new session in the mapped
operator, but that has the drawback of opening an extra session
while one already exists (render_template_fields always runs in the
context of a task run, which already has a session created in
_run_raw_task). It also caused problems in our tests, where
two open database sessions accessed the database at the same time,
producing a sqlite exception on concurrent access and a mysql
error about operations running out of sync - most likely when the
same object was modified in both sessions.

This PR changes the approach - rather than creating a new session
in the mapped operator, we retrieve the session stored by
_run_raw_task. This is done via a context manager, and adequate
protection has been added to make sure that:

a) the call is made within the context manager
b) the context manager is never initialized twice in the same
   call stack

After this change, a running task will use fewer resources, and
mapped tasks will no longer always open two DB sessions.
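
In outline, the pattern is a module-level holder plus a context
manager. A minimal standalone sketch of protections (a) and (b) -
names are simplified here, and unlike this sketch the real getter
logs a warning and falls back to a fresh session instead of raising:

```python
import contextlib

_current_session = None  # module-level holder for the task's session


@contextlib.contextmanager
def set_current_session(session):
    # Protection (b): refuse to nest -- only one holder per call stack.
    global _current_session
    if _current_session is not None:
        raise RuntimeError("Session already set for this task.")
    _current_session = session
    try:
        yield
    finally:
        _current_session = None


def get_current_session():
    # Protection (a): callers are expected to be inside the context manager.
    if _current_session is None:
        raise RuntimeError("No session set; wrap the call in set_current_session().")
    return _current_session
```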

Fixes: apache#33178
potiuk authored Sep 5, 2023
1 parent 20d8142 commit ef85c67
Showing 8 changed files with 377 additions and 300 deletions.
4 changes: 3 additions & 1 deletion airflow/cli/commands/task_command.py
@@ -67,6 +67,7 @@
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.state import DagRunState
+from airflow.utils.task_instance_session import set_current_task_instance_session
 
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
@@ -649,7 +650,8 @@ def task_render(args, dag: DAG | None = None) -> None:
     ti, _ = _get_ti(
         task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="memory"
     )
-    ti.render_templates()
+    with create_session() as session, set_current_task_instance_session(session=session):
+        ti.render_templates()
     for attr in task.template_fields:
         print(
             textwrap.dedent(
15 changes: 8 additions & 7 deletions airflow/models/mappedoperator.py
@@ -26,7 +26,6 @@
 
 import attr
 
-from airflow import settings
 from airflow.compat.functools import cache
 from airflow.exceptions import AirflowException, UnmappableOperator
 from airflow.models.abstractoperator import (
@@ -54,6 +53,7 @@
 from airflow.typing_compat import Literal
 from airflow.utils.context import context_update_for_unmapped
 from airflow.utils.helpers import is_container, prevent_duplicates
+from airflow.utils.task_instance_session import get_current_task_instance_session
 from airflow.utils.types import NOTSET
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
@@ -720,12 +720,13 @@ def render_template_fields(
         if not jinja_env:
             jinja_env = self.get_template_env()
 
-        # Ideally we'd like to pass in session as an argument to this function,
-        # but we can't easily change this function signature since operators
-        # could override this. We can't use @provide_session since it closes and
-        # expunges everything, which we don't want to do when we are so "deep"
-        # in the weeds here. We don't close this session for the same reason.
-        session = settings.Session()
+        # We retrieve the session here, stored by _run_raw_task in the
+        # set_current_task_instance_session context manager - we cannot pass the session
+        # via @provide_session because the signature of render_template_fields is defined
+        # by BaseOperator and many subclasses override it, so changing the signature is
+        # not an option. However, render_template_fields is always executed within
+        # _run_raw_task, which stores the session via set_current_task_instance_session.
+        session = get_current_task_instance_session()
 
         mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
         unmapped_task = self.unmap(mapped_kwargs)
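
The comment in the hunk above is the crux of the change: a subclass keeps the
inherited `render_template_fields` signature and picks the session up ambiently.
A hypothetical subclass illustrating the shape (`MyOperator` and `my_field` are
not part of Airflow; they are here for illustration only):

```python
from __future__ import annotations

from airflow.models.baseoperator import BaseOperator
from airflow.utils.task_instance_session import get_current_task_instance_session


class MyOperator(BaseOperator):  # hypothetical subclass, illustration only
    template_fields = ("my_field",)

    def __init__(self, my_field: str, **kwargs):
        super().__init__(**kwargs)
        self.my_field = my_field

    def render_template_fields(self, context, jinja_env=None) -> None:
        # The inherited signature has no session parameter, and it stays that way.
        # If this operator needed the session, it would fetch the one stored by
        # _run_raw_task instead of opening a second one:
        session = get_current_task_instance_session()  # noqa: F841
        super().render_template_fields(context, jinja_env)
```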
167 changes: 84 additions & 83 deletions airflow/models/taskinstance.py
@@ -118,6 +118,7 @@
 )
 from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState
 from airflow.utils.task_group import MappedTaskGroup
+from airflow.utils.task_instance_session import set_current_task_instance_session
 from airflow.utils.timeout import timeout
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
@@ -1511,98 +1512,98 @@ def _run_raw_task(
             count=0,
             tags={**self.stats_tags, "state": str(state)},
         )
-        self.task = self.task.prepare_for_execution()
-        context = self.get_template_context(ignore_param_exceptions=False)
+        with set_current_task_instance_session(session=session):
+            self.task = self.task.prepare_for_execution()
+            context = self.get_template_context(ignore_param_exceptions=False)
 
-        try:
-            if not mark_success:
-                self._execute_task_with_callbacks(context, test_mode, session=session)
-            if not test_mode:
-                self.refresh_from_db(lock_for_update=True, session=session)
-            self.state = TaskInstanceState.SUCCESS
-        except TaskDeferred as defer:
-            # The task has signalled it wants to defer execution based on
-            # a trigger.
-            self._defer_task(defer=defer, session=session)
-            self.log.info(
-                "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
-                self.dag_id,
-                self.task_id,
-                self._date_or_empty("execution_date"),
-                self._date_or_empty("start_date"),
-            )
-            if not test_mode:
-                session.add(Log(self.state, self))
-                session.merge(self)
-                session.commit()
-            return TaskReturnCode.DEFERRED
-        except AirflowSkipException as e:
-            # Recording SKIP
-            # log only if exception has any arguments to prevent log flooding
-            if e.args:
-                self.log.info(e)
-            if not test_mode:
-                self.refresh_from_db(lock_for_update=True, session=session)
-            self.state = TaskInstanceState.SKIPPED
-        except AirflowRescheduleException as reschedule_exception:
-            self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
-            session.commit()
-            return None
-        except (AirflowFailException, AirflowSensorTimeout) as e:
-            # If AirflowFailException is raised, task should not retry.
-            # If a sensor in reschedule mode reaches timeout, task should not retry.
-            self.handle_failure(e, test_mode, context, force_fail=True, session=session)
-            session.commit()
-            raise
-        except AirflowException as e:
-            if not test_mode:
-                self.refresh_from_db(lock_for_update=True, session=session)
-            # for case when task is marked as success/failed externally
-            # or dagrun timed out and task is marked as skipped
-            # current behavior doesn't hit the callbacks
-            if self.state in State.finished:
-                self.clear_next_method_args()
-                session.merge(self)
-                session.commit()
-                return None
-            else:
-                self.handle_failure(e, test_mode, context, session=session)
-                session.commit()
-                raise
-        except (Exception, KeyboardInterrupt) as e:
-            self.handle_failure(e, test_mode, context, session=session)
-            session.commit()
-            raise
-        finally:
-            Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags)
-            # Same metric with tagging
-            Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)})
+            try:
+                if not mark_success:
+                    self._execute_task_with_callbacks(context, test_mode, session=session)
+                if not test_mode:
+                    self.refresh_from_db(lock_for_update=True, session=session)
+                self.state = TaskInstanceState.SUCCESS
+            except TaskDeferred as defer:
+                # The task has signalled it wants to defer execution based on
+                # a trigger.
+                self._defer_task(defer=defer, session=session)
+                self.log.info(
+                    "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
+                    self.dag_id,
+                    self.task_id,
+                    self._date_or_empty("execution_date"),
+                    self._date_or_empty("start_date"),
+                )
+                if not test_mode:
+                    session.add(Log(self.state, self))
+                    session.merge(self)
+                    session.commit()
+                return TaskReturnCode.DEFERRED
+            except AirflowSkipException as e:
+                # Recording SKIP
+                # log only if exception has any arguments to prevent log flooding
+                if e.args:
+                    self.log.info(e)
+                if not test_mode:
+                    self.refresh_from_db(lock_for_update=True, session=session)
+                self.state = TaskInstanceState.SKIPPED
+            except AirflowRescheduleException as reschedule_exception:
+                self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
+                session.commit()
+                return None
+            except (AirflowFailException, AirflowSensorTimeout) as e:
+                # If AirflowFailException is raised, task should not retry.
+                # If a sensor in reschedule mode reaches timeout, task should not retry.
+                self.handle_failure(e, test_mode, context, force_fail=True, session=session)
+                session.commit()
+                raise
+            except AirflowException as e:
+                if not test_mode:
+                    self.refresh_from_db(lock_for_update=True, session=session)
+                # for case when task is marked as success/failed externally
+                # or dagrun timed out and task is marked as skipped
+                # current behavior doesn't hit the callbacks
+                if self.state in State.finished:
+                    self.clear_next_method_args()
+                    session.merge(self)
+                    session.commit()
+                    return None
+                else:
+                    self.handle_failure(e, test_mode, context, session=session)
+                    session.commit()
+                    raise
+            except (Exception, KeyboardInterrupt) as e:
+                self.handle_failure(e, test_mode, context, session=session)
+                session.commit()
+                raise
+            finally:
+                Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags)
+                # Same metric with tagging
+                Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)})
 
-        # Recording SKIPPED or SUCCESS
-        self.clear_next_method_args()
-        self.end_date = timezone.utcnow()
-        self._log_state()
-        self.set_duration()
+            # Recording SKIPPED or SUCCESS
+            self.clear_next_method_args()
+            self.end_date = timezone.utcnow()
+            self._log_state()
+            self.set_duration()
 
-        # run on_success_callback before db committing
-        # otherwise, the LocalTaskJob sees the state is changed to `success`,
-        # but the task_runner is still running, LocalTaskJob then treats the state is set externally!
-        self._run_finished_callback(self.task.on_success_callback, context, "on_success")
+            # run on_success_callback before db committing
+            # otherwise, the LocalTaskJob sees the state is changed to `success`,
+            # but the task_runner is still running, LocalTaskJob then treats the state is set externally!
+            self._run_finished_callback(self.task.on_success_callback, context, "on_success")
 
-        if not test_mode:
-            session.add(Log(self.state, self))
-            session.merge(self).task = self.task
-            if self.state == TaskInstanceState.SUCCESS:
-                self._register_dataset_changes(session=session)
+            if not test_mode:
+                session.add(Log(self.state, self))
+                session.merge(self).task = self.task
+                if self.state == TaskInstanceState.SUCCESS:
+                    self._register_dataset_changes(session=session)
 
-            session.commit()
-            if self.state == TaskInstanceState.SUCCESS:
-                get_listener_manager().hook.on_task_instance_success(
-                    previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session
-                )
+                session.commit()
+                if self.state == TaskInstanceState.SUCCESS:
+                    get_listener_manager().hook.on_task_instance_success(
+                        previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session
+                    )
 
-        return None
+            return None
 
     def _register_dataset_changes(self, *, session: Session) -> None:
         for obj in self.task.outlets or []:
60 changes: 60 additions & 0 deletions airflow/utils/task_instance_session.py
@@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import contextlib
import logging
import traceback
from typing import TYPE_CHECKING

from airflow.utils.session import create_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session

__current_task_instance_session: Session | None = None

log = logging.getLogger(__name__)


def get_current_task_instance_session() -> Session:
global __current_task_instance_session
if not __current_task_instance_session:
log.warning("No task session set for this task. Continuing but this likely causes a resource leak.")
log.warning("Please report this and stacktrace below to https://github.com/apache/airflow/issues")
for filename, line_number, name, line in traceback.extract_stack():
log.warning('File: "%s", %s , in %s', filename, line_number, name)
if line:
log.warning(" %s", line.strip())
__current_task_instance_session = create_session()
return __current_task_instance_session


@contextlib.contextmanager
def set_current_task_instance_session(session: Session):
global __current_task_instance_session
if __current_task_instance_session:
raise RuntimeError(
"Session already set for this task. "
"You can only have one 'set_current_task_session' context manager active at a time."
)
__current_task_instance_session = session
try:
yield
finally:
__current_task_instance_session = None
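
Taken together, the new helpers behave as in the following hypothetical usage
sketch (it assumes a working Airflow metadata database; the nested `with` is
exactly the misuse the RuntimeError guards against):

```python
from airflow.utils.session import create_session
from airflow.utils.task_instance_session import (
    get_current_task_instance_session,
    set_current_task_instance_session,
)

with create_session() as session:
    with set_current_task_instance_session(session=session):
        # Deep in the call stack (e.g. mapped render_template_fields),
        # the same session is retrieved instead of a second one being opened.
        assert get_current_task_instance_session() is session

        # Protection (b): re-entering in the same call stack raises.
        try:
            with set_current_task_instance_session(session=session):
                pass
        except RuntimeError:
            pass  # "Session already set for this task. ..."
```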
52 changes: 27 additions & 25 deletions tests/decorators/test_python.py
@@ -36,6 +36,7 @@
 from airflow.utils import timezone
 from airflow.utils.state import State
 from airflow.utils.task_group import TaskGroup
+from airflow.utils.task_instance_session import set_current_task_instance_session
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from airflow.utils.xcom import XCOM_RETURN_KEY
@@ -747,36 +748,37 @@ def test_mapped_render_template_fields(dag_maker, session):
     def fn(arg1, arg2):
         ...
 
-    with dag_maker(session=session):
-        task1 = BaseOperator(task_id="op1")
-        mapped = fn.partial(arg2="{{ ti.task_id }}").expand(arg1=task1.output)
+    with set_current_task_instance_session(session=session):
+        with dag_maker(session=session):
+            task1 = BaseOperator(task_id="op1")
+            mapped = fn.partial(arg2="{{ ti.task_id }}").expand(arg1=task1.output)
 
-    dr = dag_maker.create_dagrun()
-    ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+        dr = dag_maker.create_dagrun()
+        ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
 
-    ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session)
+        ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session)
 
-    session.add(
-        TaskMap(
-            dag_id=dr.dag_id,
-            task_id=task1.task_id,
-            run_id=dr.run_id,
-            map_index=-1,
-            length=1,
-            keys=None,
-        )
-    )
-    session.flush()
+        session.add(
+            TaskMap(
+                dag_id=dr.dag_id,
+                task_id=task1.task_id,
+                run_id=dr.run_id,
+                map_index=-1,
+                length=1,
+                keys=None,
+            )
+        )
+        session.flush()
 
-    mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session)
-    mapped_ti.map_index = 0
+        mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session)
+        mapped_ti.map_index = 0
 
-    assert isinstance(mapped_ti.task, MappedOperator)
-    mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
-    assert isinstance(mapped_ti.task, BaseOperator)
+        assert isinstance(mapped_ti.task, MappedOperator)
+        mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
+        assert isinstance(mapped_ti.task, BaseOperator)
 
-    assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}"
-    assert mapped_ti.task.op_kwargs["arg2"] == "fn"
+        assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}"
+        assert mapped_ti.task.op_kwargs["arg2"] == "fn"
 
 
 def test_task_decorator_has_wrapped_attr():