Skip to content

Commit

Permalink
AIP-66 Refactor DagRun to DagVersion association
Browse files Browse the repository at this point in the history
This removes the DagRun.dag_version association and replaces it with a
dag_versions() method on DagRun that collects all the dag_version_ids associated
with the task instances (and task instance history rows) of the DagRun.

closes: apache#46565
  • Loading branch information
ephraimbuddy committed Feb 12, 2025
1 parent ddcb728 commit b88bc1d
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 119 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
---
default_stages: [pre-commit, pre-push]
default_language_version:
python: python3
python: python3.12
node: 22.2.0
minimum_pre_commit_version: '3.2.0'
exclude: ^.*/.*_vendor/
Expand Down
17 changes: 12 additions & 5 deletions airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,7 +1660,9 @@ def _schedule_dag_run(
self.log.error("Logical date is in future: %s", dag_run.logical_date)
return callback

if not self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session):
if not dag_run.bundle_version and not self._verify_integrity_if_dag_changed(
dag_run=dag_run, session=session
):
self.log.warning(
"The DAG disappeared before verifying integrity: %s. Skipping.", dag_run.dag_id
)
Expand Down Expand Up @@ -1695,7 +1697,8 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) ->
latest_dag_version = DagVersion.get_latest_version(dag_run.dag_id, session=session)
if TYPE_CHECKING:
assert latest_dag_version
if dag_run.dag_version_id == latest_dag_version.id:
dag_version_ids = dag_run.dag_versions(session=session)
if latest_dag_version.id in dag_version_ids:
self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
return True

Expand All @@ -1704,10 +1707,14 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) ->
if not dag_run.dag:
return False

dag_run.dag_version = latest_dag_version

# Verify integrity also takes care of session.flush
dag_run.verify_integrity(session=session)
dag_run.verify_integrity(dag_version_id=latest_dag_version.id, session=session)
# Select all TIs in State.unfinished and update the dag_version_id
session.execute(
update(TI)
.where(TI.run_id == dag_run.run_id, TI.dag_id == dag_run.dag_id, TI.state.in_(State.unfinished))
.values(dag_version_id=latest_dag_version.id)
)
return True

def _send_dag_callbacks_to_processor(self, dag: DAG, callback: DagCallbackRequest | None = None) -> None:
Expand Down
10 changes: 0 additions & 10 deletions airflow/migrations/versions/0047_3_0_0_add_dag_versioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,6 @@ def upgrade():
batch_op.add_column(sa.Column("dag_version_id", UUIDType(binary=False)))

with op.batch_alter_table("dag_run", schema=None) as batch_op:
batch_op.add_column(sa.Column("dag_version_id", UUIDType(binary=False)))
batch_op.create_foreign_key(
batch_op.f("dag_run_dag_version_id_fkey"),
"dag_version",
["dag_version_id"],
["id"],
ondelete="CASCADE",
)
batch_op.drop_column("dag_hash")


Expand Down Expand Up @@ -199,7 +191,5 @@ def downgrade():

with op.batch_alter_table("dag_run", schema=None) as batch_op:
batch_op.add_column(sa.Column("dag_hash", sa.String(length=32), autoincrement=False, nullable=True))
batch_op.drop_constraint(batch_op.f("dag_run_dag_version_id_fkey"), type_="foreignkey")
batch_op.drop_column("dag_version_id")

op.drop_table("dag_version")
7 changes: 5 additions & 2 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ def _create_orm_dagrun(
conf=conf,
state=state,
run_type=run_type,
dag_version=dag_version,
creating_job_id=creating_job_id,
data_interval=data_interval,
triggered_by=triggered_by,
Expand All @@ -287,7 +286,11 @@ def _create_orm_dagrun(
run.dag = dag
# create the associated task instances
# state is None at the moment of creation
run.verify_integrity(session=session)
if not dag_version:
dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)
if not dag_version:
raise AirflowException(f"Could not find a version for DAG {dag.dag_id}")
run.verify_integrity(dag_version_id=dag_version.id, session=session)
return run


Expand Down
1 change: 0 additions & 1 deletion airflow/models/dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class DagVersion(Base):
cascade="all, delete, delete-orphan",
cascade_backrefs=False,
)
dag_runs = relationship("DagRun", back_populates="dag_version", cascade="all, delete, delete-orphan")
task_instances = relationship("TaskInstance", back_populates="dag_version")
created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)

Expand Down
54 changes: 34 additions & 20 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from sqlalchemy.orm import declared_attr, joinedload, relationship, synonym, validates
from sqlalchemy.sql.expression import case, false, select, true
from sqlalchemy.sql.functions import coalesce
from sqlalchemy_utils import UUIDType

from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
Expand All @@ -61,8 +60,8 @@
from airflow.models.abstractoperator import NotMapped
from airflow.models.backfill import Backfill
from airflow.models.base import Base, StringID
from airflow.models.dag_version import DagVersion
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.taskinstancehistory import TaskInstanceHistory as TIH
from airflow.models.tasklog import LogTemplate
from airflow.models.taskmap import TaskMap
from airflow.stats import Stats
Expand All @@ -84,9 +83,11 @@
from datetime import datetime

from sqlalchemy.orm import Query, Session
from sqlalchemy_utils import UUIDType

from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.dag_version import DagVersion
from airflow.models.operator import Operator
from airflow.typing_compat import Literal
from airflow.utils.types import ArgNotSet
Expand Down Expand Up @@ -172,8 +173,6 @@ class DagRun(Base, LoggingMixin):
It's possible this could change if e.g. the dag run is cleared to be rerun, or perhaps re-backfilled.
"""
dag_version_id = Column(UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"))
dag_version = relationship("DagVersion", back_populates="dag_runs")
bundle_version = Column(StringID())

# Remove this `if` after upgrading Sphinx-AutoAPI
Expand Down Expand Up @@ -250,7 +249,6 @@ def __init__(
data_interval: tuple[datetime, datetime] | None = None,
triggered_by: DagRunTriggeredByType | None = None,
backfill_id: int | None = None,
dag_version: DagVersion | None = None,
bundle_version: str | None = None,
):
if data_interval is None:
Expand Down Expand Up @@ -278,7 +276,6 @@ def __init__(
self.backfill_id = backfill_id
self.clear_number = 0
self.triggered_by = triggered_by
self.dag_version = dag_version
super().__init__()

def __repr__(self):
Expand All @@ -300,6 +297,21 @@ def validate_run_id(self, key: str, run_id: str) -> str | None:
f"The run_id provided '{run_id}' does not match regex pattern '{regex}' or '{RUN_ID_REGEX}'"
)

@provide_session
def dag_versions(self, session: Session = NEW_SESSION) -> list[UUIDType]:
    """
    Return the IDs of the DAG versions referenced by the TIs of this DagRun.

    Both the current task instances and the task instance history rows that
    share this run's ``run_id`` and ``dag_id`` are consulted. The result holds
    ``dag_version_id`` values (not ``DagVersion`` objects), each listed once,
    in first-seen order with task instance rows first.

    :param session: database session
    :return: distinct ``dag_version_id`` values found for this run
    """
    # NOTE(review): a ``None`` entry can appear if some rows carry no
    # dag_version_id -- confirm whether callers want it filtered out.
    ti_version_ids = session.scalars(
        select(TI.dag_version_id).where(TI.run_id == self.run_id, TI.dag_id == self.dag_id).distinct()
    ).all()
    tih_version_ids = session.scalars(
        select(TIH.dag_version_id).where(TIH.run_id == self.run_id, TIH.dag_id == self.dag_id).distinct()
    ).all()
    # Each query is DISTINCT on its own, but the same version ID can be present
    # in both tables; dict.fromkeys drops the repeats while keeping order.
    return list(dict.fromkeys([*ti_version_ids, *tih_version_ids]))

@property
def stats_tags(self) -> dict[str, str]:
return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type})
Expand Down Expand Up @@ -935,7 +947,7 @@ def recalculate(self) -> _UnfinishedStates:
filepath=self.dag_model.relative_fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
bundle_name=self.dag_version.bundle_name,
bundle_name=self.dag_model.bundle_name,
bundle_version=self.bundle_version,
is_failure_callback=True,
msg="task_failure",
Expand Down Expand Up @@ -964,7 +976,7 @@ def recalculate(self) -> _UnfinishedStates:
filepath=self.dag_model.relative_fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
bundle_name=self.dag_version.bundle_name,
bundle_name=self.dag_model.bundle_name,
bundle_version=self.bundle_version,
is_failure_callback=False,
msg="success",
Expand All @@ -983,7 +995,7 @@ def recalculate(self) -> _UnfinishedStates:
filepath=self.dag_model.relative_fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
bundle_name=self.dag_version.bundle_name,
bundle_name=self.dag_model.bundle_name,
bundle_version=self.bundle_version,
is_failure_callback=True,
msg="all_tasks_deadlocked",
Expand All @@ -998,9 +1010,8 @@ def recalculate(self) -> _UnfinishedStates:
"DagRun Finished: dag_id=%s, logical_date=%s, run_id=%s, "
"run_start_date=%s, run_end_date=%s, run_duration=%s, "
"state=%s, external_trigger=%s, run_type=%s, "
"data_interval_start=%s, data_interval_end=%s, dag_version_name=%s"
"data_interval_start=%s, data_interval_end=%s,"
)
dagv = session.scalar(select(DagVersion).where(DagVersion.id == self.dag_version_id))
self.log.info(
msg,
self.dag_id,
Expand All @@ -1018,10 +1029,9 @@ def recalculate(self) -> _UnfinishedStates:
self.run_type,
self.data_interval_start,
self.data_interval_end,
dagv.version if dagv else None,
)

self._trace_dagrun(dagv)
self._trace_dagrun()

session.flush()

Expand All @@ -1033,7 +1043,7 @@ def recalculate(self) -> _UnfinishedStates:

return schedulable_tis, callback

def _trace_dagrun(self, dagv) -> None:
def _trace_dagrun(self) -> None:
with Trace.start_span_from_dagrun(dagrun=self) as span:
if self._state == DagRunState.FAILED:
span.set_attribute("error", True)
Expand All @@ -1053,7 +1063,6 @@ def _trace_dagrun(self, dagv) -> None:
"run_type": str(self.run_type),
"data_interval_start": str(self.data_interval_start),
"data_interval_end": str(self.data_interval_end),
"dag_version": str(dagv.version if dagv else None),
"conf": str(self.conf),
}
if span.is_recording():
Expand Down Expand Up @@ -1308,13 +1317,13 @@ def _emit_duration_stats_for_finished_state(self):
Stats.timing(f"dagrun.duration.{self.state}", **timer_params)

@provide_session
def verify_integrity(self, *, session: Session = NEW_SESSION) -> None:
def verify_integrity(self, *, dag_version_id: UUIDType, session: Session = NEW_SESSION) -> None:
"""
Verify the DagRun by checking for removed tasks or tasks that are not in the database yet.
It will set state to removed or add the task if required.
:missing_indexes: A dictionary of task vs indexes that are missing.
:param dag_version_id: The DAG version ID
:param session: Sqlalchemy ORM Session
"""
from airflow.settings import task_instance_mutation_hook
Expand All @@ -1340,7 +1349,9 @@ def task_filter(task: Operator) -> bool:
)

created_counts: dict[str, int] = defaultdict(int)
task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)
task_creator = self._get_task_creator(
created_counts, task_instance_mutation_hook, hook_is_noop, dag_version_id
)

# Create the missing tasks, including mapped tasks
tasks_to_create = (task for task in dag.task_dict.values() if task_filter(task))
Expand Down Expand Up @@ -1436,6 +1447,7 @@ def _get_task_creator(
created_counts: dict[str, int],
ti_mutation_hook: Callable,
hook_is_noop: Literal[True],
dag_version_id: UUIDType,
) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]: ...

@overload
Expand All @@ -1444,13 +1456,15 @@ def _get_task_creator(
created_counts: dict[str, int],
ti_mutation_hook: Callable,
hook_is_noop: Literal[False],
dag_version_id: UUIDType,
) -> Callable[[Operator, Iterable[int]], Iterator[TI]]: ...

def _get_task_creator(
self,
created_counts: dict[str, int],
ti_mutation_hook: Callable,
hook_is_noop: Literal[True, False],
dag_version_id: UUIDType,
) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]:
"""
Get the task creator function.
Expand All @@ -1468,7 +1482,7 @@ def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[s
created_counts[task.task_type] += 1
for map_index in indexes:
yield TI.insert_mapping(
self.run_id, task, map_index=map_index, dag_version_id=self.dag_version_id
self.run_id, task, map_index=map_index, dag_version_id=dag_version_id
)

creator = create_ti_mapping
Expand All @@ -1477,7 +1491,7 @@ def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[s

def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
for map_index in indexes:
ti = TI(task, run_id=self.run_id, map_index=map_index, dag_version_id=self.dag_version_id)
ti = TI(task, run_id=self.run_id, map_index=map_index, dag_version_id=dag_version_id)
ti_mutation_hook(ti)
created_counts[ti.operator] += 1
yield ti
Expand Down
Loading

0 comments on commit b88bc1d

Please sign in to comment.