Add triggers for ExternalTask (apache#29313)
Contributes back two of the core triggers from https://github.com/astronomer/astronomer-providers so that they can be used to create an operator/sensor or used within the TaskFlow API.
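For context, a minimal sketch (not part of this commit) of how a deferrable sensor might hand off to one of these triggers; `WaitForExternalDag` and its arguments are hypothetical:

from __future__ import annotations

from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.external_task import DagStateTrigger
from airflow.utils import timezone
from airflow.utils.context import Context


class WaitForExternalDag(BaseSensorOperator):
    """Hypothetical deferrable sensor: frees its worker slot while waiting."""

    def execute(self, context: Context) -> None:
        # Defer to the triggerer process; the worker slot is released until
        # the trigger yields a TriggerEvent.
        self.defer(
            trigger=DagStateTrigger(
                dag_id="other_dag",
                states=["success"],
                execution_dates=[timezone.datetime(2023, 1, 1)],
                poll_interval=5.0,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: bool | None = None) -> None:
        # Invoked on a worker once the trigger fires; event is the payload (True).
        self.log.info("External DAG reached an allowed state.")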
  • Loading branch information
1 parent cd5a92c commit 6ec97dc
Showing 2 changed files with 325 additions and 0 deletions.
163 changes: 163 additions & 0 deletions airflow/triggers/external_task.py
@@ -0,0 +1,163 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
import datetime
import typing
from typing import Any

from asgiref.sync import sync_to_async
from sqlalchemy import func
from sqlalchemy.orm import Session

from airflow.models import DagRun, TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import provide_session


class TaskStateTrigger(BaseTrigger):
"""
Waits asynchronously for a task in a different DAG to complete for a
specific logical date.
:param dag_id: The dag_id that contains the task you want to wait for
:param task_id: The task_id that contains the task you want to
wait for. If ``None`` (default value) the sensor waits for the DAG
:param states: allowed states, default is ``['success']``
:param execution_dates:
:param poll_interval: The time interval in seconds to check the state.
The default value is 5 sec.
"""

def __init__(
self,
dag_id: str,
task_id: str,
states: list[str],
execution_dates: list[datetime.datetime],
poll_interval: float = 5.0,
):
super().__init__()
self.dag_id = dag_id
self.task_id = task_id
self.states = states
self.execution_dates = execution_dates
self.poll_interval = poll_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes TaskStateTrigger arguments and classpath."""
return (
"airflow.triggers.external_task.TaskStateTrigger",
{
"dag_id": self.dag_id,
"task_id": self.task_id,
"states": self.states,
"execution_dates": self.execution_dates,
"poll_interval": self.poll_interval,
},
)

async def run(self) -> typing.AsyncIterator["TriggerEvent"]:
"""
Checks periodically in the database to see if the task exists, and has
hit one of the states yet, or not.
"""
while True:
num_tasks = await self.count_tasks()
if num_tasks == len(self.execution_dates):
yield TriggerEvent(True)
await asyncio.sleep(self.poll_interval)

@sync_to_async
@provide_session
def count_tasks(self, session: Session) -> int | None:
"""Count how many task instances in the database match our criteria."""
count = (
session.query(func.count("*")) # .count() is inefficient
.filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.state.in_(self.states),
TaskInstance.execution_date.in_(self.execution_dates),
)
.scalar()
)
return typing.cast(int, count)


class DagStateTrigger(BaseTrigger):
"""
Waits asynchronously for a DAG to complete for a specific logical date.
:param dag_id: The dag_id that contains the task you want to wait for
:param states: allowed states, default is ``['success']``
:param execution_dates: The logical date at which DAG run.
:param poll_interval: The time interval in seconds to check the state.
The default value is 5.0 sec.
"""

def __init__(
self,
dag_id: str,
states: list[str],
execution_dates: list[datetime.datetime],
poll_interval: float = 5.0,
):
super().__init__()
self.dag_id = dag_id
self.states = states
self.execution_dates = execution_dates
self.poll_interval = poll_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes DagStateTrigger arguments and classpath."""
return (
"airflow.triggers.external_task.DagStateTrigger",
{
"dag_id": self.dag_id,
"states": self.states,
"execution_dates": self.execution_dates,
"poll_interval": self.poll_interval,
},
)

async def run(self) -> typing.AsyncIterator["TriggerEvent"]:
"""
Checks periodically in the database to see if the dag run exists, and has
hit one of the states yet, or not.
"""
while True:
num_dags = await self.count_dags()
if num_dags == len(self.execution_dates):
yield TriggerEvent(True)
await asyncio.sleep(self.poll_interval)

@sync_to_async
@provide_session
def count_dags(self, session: Session) -> int | None:
"""Count how many dag runs in the database match our criteria."""
count = (
session.query(func.count("*")) # .count() is inefficient
.filter(
DagRun.dag_id == self.dag_id,
DagRun.state.in_(self.states),
DagRun.execution_date.in_(self.execution_dates),
)
.scalar()
)
return typing.cast(int, count)
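Both triggers implement the same contract: serialize() must return a classpath and kwargs from which the triggerer can rebuild an equivalent instance. A minimal sketch (not part of this commit) of that round trip, using airflow.utils.module_loading.import_string:

from airflow.triggers.external_task import DagStateTrigger
from airflow.utils import timezone
from airflow.utils.module_loading import import_string

trigger = DagStateTrigger(
    dag_id="other_dag",
    states=["success"],
    execution_dates=[timezone.datetime(2023, 1, 1)],
    poll_interval=5.0,
)
classpath, kwargs = trigger.serialize()
rebuilt = import_string(classpath)(**kwargs)  # an equivalent trigger instance
assert rebuilt.serialize() == trigger.serialize()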
162 changes: 162 additions & 0 deletions tests/triggers/test_external_task.py
@@ -0,0 +1,162 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio

import pytest

from airflow import DAG
from airflow.models import DagRun, TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.triggers.external_task import DagStateTrigger, TaskStateTrigger
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState


class TestTaskStateTrigger:
DAG_ID = "external_task"
TASK_ID = "external_task_op"
RUN_ID = "external_task_run_id"
STATES = ["success", "failed"]

@pytest.mark.asyncio
async def test_task_state_trigger(self, session):
"""
        Asserts that the TaskStateTrigger only fires after a TaskInstance
        reaches one of the allowed states (here, SUCCESS).
"""
dag = DAG(self.DAG_ID, start_date=timezone.datetime(2022, 1, 1))
dag_run = DagRun(
dag_id=dag.dag_id,
run_type="manual",
execution_date=timezone.datetime(2022, 1, 1),
run_id=self.RUN_ID,
)
session.add(dag_run)
session.commit()

external_task = EmptyOperator(task_id=self.TASK_ID, dag=dag)
instance = TaskInstance(external_task, timezone.datetime(2022, 1, 1))
session.add(instance)
session.commit()

trigger = TaskStateTrigger(
dag_id=dag.dag_id,
task_id=instance.task_id,
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=0.2,
)

task = asyncio.create_task(trigger.run().__anext__())
await asyncio.sleep(0.5)

# It should not have produced a result
assert task.done() is False

# Progress the task to a "success" state so that run() yields a TriggerEvent
instance.state = TaskInstanceState.SUCCESS
session.commit()
await asyncio.sleep(0.5)
assert task.done() is True

# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()

def test_serialization(self):
"""
Asserts that the TaskStateTrigger correctly serializes its arguments
and classpath.
"""
trigger = TaskStateTrigger(
dag_id=self.DAG_ID,
task_id=self.TASK_ID,
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=5,
)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.triggers.external_task.TaskStateTrigger"
assert kwargs == {
"dag_id": self.DAG_ID,
"task_id": self.TASK_ID,
"states": self.STATES,
"execution_dates": [timezone.datetime(2022, 1, 1)],
"poll_interval": 5,
}


class TestDagStateTrigger:
DAG_ID = "test_dag_state_trigger"
RUN_ID = "external_task_run_id"
STATES = ["success", "failed"]

@pytest.mark.asyncio
async def test_dag_state_trigger(self, session):
"""
        Asserts that the DagStateTrigger only fires after a DagRun
        reaches one of the allowed states (here, SUCCESS).
"""
dag = DAG(self.DAG_ID, start_date=timezone.datetime(2022, 1, 1))
dag_run = DagRun(
dag_id=dag.dag_id,
run_type="manual",
execution_date=timezone.datetime(2022, 1, 1),
run_id=self.RUN_ID,
)
session.add(dag_run)
session.commit()

trigger = DagStateTrigger(
dag_id=dag.dag_id,
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=0.2,
)

task = asyncio.create_task(trigger.run().__anext__())
await asyncio.sleep(0.5)

# It should not have produced a result
assert task.done() is False

# Progress the dag run to a "success" state so that run() yields a TriggerEvent
dag_run.state = DagRunState.SUCCESS
session.commit()
await asyncio.sleep(0.5)
assert task.done() is True

# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()

def test_serialization(self):
"""Asserts that the DagStateTrigger correctly serializes its arguments and classpath."""
trigger = DagStateTrigger(
dag_id=self.DAG_ID,
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=5,
)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.triggers.external_task.DagStateTrigger"
assert kwargs == {
"dag_id": self.DAG_ID,
"states": self.STATES,
"execution_dates": [timezone.datetime(2022, 1, 1)],
"poll_interval": 5,
}
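Outside of pytest, the same polling loop can be exercised by hand; a minimal sketch (not part of this commit), assuming a configured metadata database that already contains a matching DAG run:

import asyncio

from airflow.triggers.external_task import DagStateTrigger
from airflow.utils import timezone


async def wait_for_dag() -> None:
    trigger = DagStateTrigger(
        dag_id="external_task",
        states=["success"],
        execution_dates=[timezone.datetime(2022, 1, 1)],
        poll_interval=0.2,
    )
    # run() is an async generator; the first yielded TriggerEvent ends the wait.
    event = await trigger.run().__anext__()
    print(event.payload)  # True once all runs reach an allowed state


asyncio.run(wait_for_dag())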
