Skip to content

Commit

Permalink
Better @task_group typing powered by ParamSpec and pre-commit (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Sep 23, 2022
1 parent 051ba15 commit 7179eba
Show file tree
Hide file tree
Showing 11 changed files with 539 additions and 493 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -465,12 +465,12 @@ repos:
entry: ./scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
pass_filenames: false
files: ^airflow/models/(?:base|mapped)operator\.py$
- id: check-dag-init-decorator-arguments
name: Check DAG and @dag arguments
- id: check-init-decorator-arguments
name: Check model __init__ and decorator arguments are in sync
language: python
entry: ./scripts/ci/pre_commit/pre_commit_sync_dag_init_decorator.py
entry: ./scripts/ci/pre_commit/pre_commit_sync_init_decorator.py
pass_filenames: false
files: ^airflow/models/dag\.py$
files: ^airflow/models/dag\.py$|^airflow/(?:decorators|utils)/task_group.py$
- id: check-base-operator-usage
language: pygrep
name: Check BaseOperator[Link] core imports
Expand Down
4 changes: 2 additions & 2 deletions STATIC_CODE_CHECKS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ require Breeze Docker image to be build locally.
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| check-core-deprecation-classes | Verify using of dedicated Airflow deprecation classes in core | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| check-dag-init-decorator-arguments | Check DAG and @dag arguments | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| check-daysago-import-from-utils | Make sure days_ago is imported from airflow.utils.dates | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| check-decorated-operator-implements-custom-name | Check @task decorator implements custom_operator_name | |
Expand All @@ -179,6 +177,8 @@ require Breeze Docker image to be build locally.
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| check-incorrect-use-of-LoggingMixin | Make sure LoggingMixin is not used alone | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| check-init-decorator-arguments | Check model __init__ and decorator arguments are in sync | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| check-lazy-logging | Check that all logging methods are lazy | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| check-merge-conflict | Check that merge conflicts are not being committed | |
Expand Down
80 changes: 36 additions & 44 deletions airflow/decorators/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,42 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
"""Implements the ``@task_group`` function decorator.
When the decorated function is called, a task group will be created to represent
a collection of closely related tasks on the same DAG that should be grouped
together when the DAG is displayed graphically.
"""

from __future__ import annotations

import functools
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, overload

import attr

from airflow.models.taskmixin import DAGNode
from airflow.typing_compat import ParamSpec
from airflow.utils.task_group import TaskGroup

if TYPE_CHECKING:
from airflow.models.dag import DAG
from airflow.models.mappedoperator import OperatorExpandArgument
from airflow.models.expandinput import OperatorExpandArgument, OperatorExpandKwargsArgument

F = TypeVar("F", bound=Callable)
R = TypeVar("R")
FParams = ParamSpec("FParams")
FReturn = TypeVar("FReturn", None, DAGNode)

task_group_sig = signature(TaskGroup.__init__)


@attr.define
class TaskGroupDecorator(Generic[R]):
""":meta private:"""

function: Callable[..., R | None] = attr.ib(validator=attr.validators.is_callable())
class _TaskGroupFactory(Generic[FParams, FReturn]):
function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
kwargs: dict[str, Any] = attr.ib(factory=dict)
"""kwargs for the TaskGroup"""

@function.validator
def _validate_function(self, _, f):
def _validate_function(self, _, f: Callable[FParams, FReturn]):
if 'self' in signature(f).parameters:
raise TypeError('@task_group does not support methods')

Expand All @@ -57,13 +59,18 @@ def _validate(self, _, kwargs):
task_group_sig.bind_partial(**kwargs)

def __attrs_post_init__(self):
self.kwargs.setdefault('group_id', self.function.__name__)
if not self.kwargs.get("group_id"):
self.kwargs["group_id"] = self.function.__name__

def _make_task_group(self, **kwargs) -> TaskGroup:
return TaskGroup(**kwargs)
def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> DAGNode:
"""Instantiate the task group.
def __call__(self, *args, **kwargs) -> R | TaskGroup:
with self._make_task_group(add_suffix_on_collision=True, **self.kwargs) as task_group:
This uses the wrapped function to create a task group. Depending on the
return type of the wrapped function, this either returns the last task
in the group, or the group itself, to support task chaining.
"""
with TaskGroup(add_suffix_on_collision=True, **self.kwargs) as task_group:
if self.function.__doc__ and not task_group.tooltip:
task_group.tooltip = self.function.__doc__

Expand All @@ -85,31 +92,17 @@ def __call__(self, *args, **kwargs) -> R | TaskGroup:
# start >> tg >> end
return task_group

def override(self, **kwargs: Any) -> TaskGroupDecorator[R]:
def override(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]:
return attr.evolve(self, kwargs={**self.kwargs, **kwargs})

def partial(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]:
raise NotImplementedError("TODO: Implement me")

class Group(Generic[F]):
"""Declaration of a @task_group-decorated callable for type-checking.
An instance of this type inherits the call signature of the decorated
function wrapped in it (not *exactly* since it actually turns the function
into an XComArg-compatible, but there's no way to express that right now),
and provides two additional methods for task-mapping.
def expand(self, **kwargs: OperatorExpandArgument) -> TaskGroup:
raise NotImplementedError("TODO: Implement me")

This type is implemented by ``TaskGroupDecorator`` at runtime.
"""

__call__: F

function: F

# Return value should match F's return type, but that's impossible to declare.
def expand(self, **kwargs: OperatorExpandArgument) -> Any:
...

def partial(self, **kwargs: Any) -> Group[F]:
...
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> TaskGroup:
raise NotImplementedError("TODO: Implement me")


# This covers the @task_group() case. Annotations are copied from the TaskGroup
Expand All @@ -130,19 +123,18 @@ def task_group(
ui_color: str = "CornflowerBlue",
ui_fgcolor: str = "#000",
add_suffix_on_collision: bool = False,
) -> Callable[[F], Group[F]]:
) -> Callable[FParams, _TaskGroupFactory[FParams, FReturn]]:
...


# This covers the @task_group case (no parentheses).
@overload
def task_group(python_callable: F) -> Group[F]:
def task_group(python_callable: Callable[FParams, FReturn]) -> _TaskGroupFactory[FParams, FReturn]:
...


def task_group(python_callable=None, **tg_kwargs):
"""
Python TaskGroup decorator.
"""Python TaskGroup decorator.
This wraps a function into an Airflow TaskGroup. When used as the
``@task_group()`` form, all arguments are forwarded to the underlying
Expand All @@ -151,6 +143,6 @@ def task_group(python_callable=None, **tg_kwargs):
:param python_callable: Function to decorate.
:param tg_kwargs: Keyword arguments for the TaskGroup object.
"""
if callable(python_callable):
return TaskGroupDecorator(function=python_callable, kwargs=tg_kwargs)
return cast(Callable[[F], F], functools.partial(TaskGroupDecorator, kwargs=tg_kwargs))
if callable(python_callable) and not tg_kwargs:
return _TaskGroupFactory(function=python_callable, kwargs=tg_kwargs)
return functools.partial(_TaskGroupFactory, kwargs=tg_kwargs)
5 changes: 3 additions & 2 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
DuplicateTaskIdFound,
TaskAlreadyInTaskGroup,
)
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.serialization.enums import DagAttributeTypes
from airflow.utils.helpers import validate_group_key
Expand Down Expand Up @@ -79,7 +78,7 @@ def __init__(
prefix_group_id: bool = True,
parent_group: TaskGroup | None = None,
dag: DAG | None = None,
default_args: dict | None = None,
default_args: dict[str, Any] | None = None,
tooltip: str = "",
ui_color: str = "CornflowerBlue",
ui_fgcolor: str = "#000",
Expand Down Expand Up @@ -497,6 +496,8 @@ def task_group_to_dict(task_item_or_group):
Create a nested dict representation of this TaskGroup and its children used to construct
the Graph.
"""
from airflow.models.abstractoperator import AbstractOperator

if isinstance(task_item_or_group, AbstractOperator):
return {
'id': task_item_or_group.task_id,
Expand Down
2 changes: 1 addition & 1 deletion dev/breeze/src/airflow_breeze/pre_commit_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
'check-builtin-literals',
'check-changelog-has-no-duplicates',
'check-core-deprecation-classes',
'check-dag-init-decorator-arguments',
'check-daysago-import-from-utils',
'check-decorated-operator-implements-custom-name',
'check-docstring-param-types',
Expand All @@ -46,6 +45,7 @@
'check-for-inclusive-language',
'check-hooks-apply',
'check-incorrect-use-of-LoggingMixin',
'check-init-decorator-arguments',
'check-lazy-logging',
'check-merge-conflict',
'check-newsfragments-are-valid',
Expand Down
2 changes: 1 addition & 1 deletion images/breeze/output-commands-hash.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ setup:version:d11da4c17a23179830079b646160149c
setup:09e876968e669155b4aae7423a19e7b2
shell:4de9c18e581853f332767beddb95b425
start-airflow:eef91445684e015f83d91d02f4f03ccc
static-checks:6bf06066680e36de71ccff5f7ba0dec2
static-checks:425cd78507278494e345fb7648260c24
stop:8ebd8a42f1003495d37b884de5ac7ce6
testing:docker-compose-tests:3e07be65e30219930d3c62a593dd8c6a
testing:helm-tests:403231f0a94b261f9c7aae8aea03ec50
Expand Down
Loading

0 comments on commit 7179eba

Please sign in to comment.