Skip to content

Commit

Permalink
Implement AIP-44 compliant DagRun.get_task_instances method (apache#4…
Browse files Browse the repository at this point in the history
  • Loading branch information
potiuk authored Jul 25, 2024
1 parent ba78d54 commit ae65820
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ repos:
entry: ./scripts/ci/pre_commit/new_session_in_provide_session.py
pass_filenames: true
files: ^airflow/.+\.py$
exclude: ^airflow/serialization/pydantic/.*
- id: check-for-inclusive-language
language: pygrep
name: Check for language that we do not accept as community
Expand Down
1 change: 1 addition & 0 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def initialize_method_map() -> dict[str, Callable]:
DagRun.get_previous_scheduled_dagrun,
DagRun.fetch_task_instance,
DagRun._get_log_template,
DagRun._get_task_instances,
RenderedTaskInstanceFields._update_runtime_evaluated_template_fields,
SerializedDagModel.get_serialized_dag,
SerializedDagModel.remove_deleted_dags,
Expand Down
6 changes: 5 additions & 1 deletion airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ def get_task_instances(
Redirect to DagRun.fetch_task_instances method.
Keep this method because it is widely used across the code.
"""
task_ids = self.dag.task_ids if self.dag and self.dag.partial else None
task_ids = DagRun._get_partial_task_ids(self.dag)
return DagRun.fetch_task_instances(
dag_id=self.dag_id, run_id=self.run_id, task_ids=task_ids, state=state, session=session
)
Expand Down Expand Up @@ -1678,6 +1678,10 @@ def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str:
)
return self.get_log_template(session=session).filename

@staticmethod
def _get_partial_task_ids(dag: DAG | None) -> list[str] | None:
return dag.task_ids if dag and dag.partial else None


class DagRunNote(Base):
"""For storage of arbitrary notes concerning the dagrun instance."""
Expand Down
17 changes: 12 additions & 5 deletions airflow/serialization/pydantic/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from datetime import datetime
from typing import TYPE_CHECKING, Iterable, List, Optional

from airflow.models.dagrun import DagRun
from airflow.serialization.pydantic.dag import PydanticDag
from airflow.serialization.pydantic.dataset import DatasetEventPydantic
from airflow.utils.pydantic import BaseModel as BaseModelPydantic, ConfigDict, is_pydantic_2_installed
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -62,18 +62,25 @@ class DagRunPydantic(BaseModelPydantic):
def logical_date(self) -> datetime:
return self.execution_date

@provide_session
def get_task_instances(
self,
state: Iterable[TaskInstanceState | None] | None = None,
session: Session = NEW_SESSION,
session=None,
) -> list[TI]:
"""
Return the task instances for this dag run.
TODO: make it works for AIP-44
Redirect to DagRun.fetch_task_instances method.
Keep this method because it is widely used across the code.
"""
raise NotImplementedError()
task_ids = DagRun._get_partial_task_ids(self.dag)
return DagRun.fetch_task_instances(
dag_id=self.dag_id,
run_id=self.run_id,
task_ids=task_ids,
state=state,
session=session,
)

def get_task_instance(
self,
Expand Down

0 comments on commit ae65820

Please sign in to comment.