From b263dbcb0f84fd9029591d1447a7c843cb970f15 Mon Sep 17 00:00:00 2001
From: Amogh Desai
Date: Fri, 16 Dec 2022 01:59:15 +0530
Subject: [PATCH] Adding an example dag for dynamic task mapping (#28325)

---
 .../example_dynamic_task_mapping.py           | 38 +++++++++
 .../test_docker_compose_quick_start.py        |  1 -
 .../concepts/dynamic-task-mapping.rst         | 23 +-----
 docs/build_docs.py                            |  5 +-
 docs/exts/docs_build/spelling_checks.py       |  2 +-
 tests/serialization/test_dag_serialization.py | 77 +++++++++++++------
 6 files changed, 99 insertions(+), 47 deletions(-)
 create mode 100644 airflow/example_dags/example_dynamic_task_mapping.py

diff --git a/airflow/example_dags/example_dynamic_task_mapping.py b/airflow/example_dags/example_dynamic_task_mapping.py
new file mode 100644
index 0000000000000..dce6cda20972c
--- /dev/null
+++ b/airflow/example_dags/example_dynamic_task_mapping.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example DAG demonstrating the usage of dynamic task mapping."""
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow import DAG
+from airflow.decorators import task
+
+with DAG(dag_id="example_dynamic_task_mapping", start_date=datetime(2022, 3, 4)) as dag:
+
+    @task
+    def add_one(x: int):
+        return x + 1
+
+    @task
+    def sum_it(values):
+        total = sum(values)
+        print(f"Total was {total}")
+
+    added_values = add_one.expand(x=[1, 2, 3])
+    sum_it(added_values)
diff --git a/docker_tests/test_docker_compose_quick_start.py b/docker_tests/test_docker_compose_quick_start.py
index 6f25f625788d4..fd553ed1756eb 100644
--- a/docker_tests/test_docker_compose_quick_start.py
+++ b/docker_tests/test_docker_compose_quick_start.py
@@ -27,7 +27,6 @@
 from unittest import mock
 
 import requests
-
 from docker_tests.command_utils import run_command
 from docker_tests.constants import SOURCE_ROOT
 from docker_tests.docker_tests_utils import docker_image
diff --git a/docs/apache-airflow/concepts/dynamic-task-mapping.rst b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
index 5ae0e9fb82031..d15c0ada77155 100644
--- a/docs/apache-airflow/concepts/dynamic-task-mapping.rst
+++ b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
@@ -30,27 +30,10 @@ Simple mapping
 In its simplest form you can map over a list defined directly in your DAG file using the ``expand()`` function instead of calling your task directly.
 
-.. code-block:: python
-
-    from datetime import datetime
-
-    from airflow import DAG
-    from airflow.decorators import task
-
-
-    with DAG(dag_id="simple_mapping", start_date=datetime(2022, 3, 4)) as dag:
-
-        @task
-        def add_one(x: int):
-            return x + 1
-
-        @task
-        def sum_it(values):
-            total = sum(values)
-            print(f"Total was {total}")
+The example below shows a simple usage of dynamic task mapping:
 
-        added_values = add_one.expand(x=[1, 2, 3])
-        sum_it(added_values)
+.. exampleinclude:: /../../airflow/example_dags/example_dynamic_task_mapping.py
+    :language: python
 
 This will show ``Total was 9`` in the task logs when executed.
 
diff --git a/docs/build_docs.py b/docs/build_docs.py
index d1fb06ccacad1..cd6c83249d4f3 100755
--- a/docs/build_docs.py
+++ b/docs/build_docs.py
@@ -25,9 +25,6 @@
 from itertools import filterfalse, tee
 from typing import Callable, Iterable, NamedTuple, TypeVar
 
-from rich.console import Console
-from tabulate import tabulate
-
 from docs.exts.docs_build import dev_index_generator, lint_checks
 from docs.exts.docs_build.code_utils import CONSOLE_WIDTH, PROVIDER_INIT_FILE
 from docs.exts.docs_build.docs_builder import DOCS_DIR, AirflowDocsBuilder, get_available_packages
@@ -36,6 +33,8 @@
 from docs.exts.docs_build.github_action_utils import with_group
 from docs.exts.docs_build.package_filter import process_package_filters
 from docs.exts.docs_build.spelling_checks import SpellingError, display_spelling_error_summary
+from rich.console import Console
+from tabulate import tabulate
 
 TEXT_RED = "\033[31m"
 TEXT_RESET = "\033[0m"
diff --git a/docs/exts/docs_build/spelling_checks.py b/docs/exts/docs_build/spelling_checks.py
index bbaa9fa5dde79..f89bfa50dc587 100644
--- a/docs/exts/docs_build/spelling_checks.py
+++ b/docs/exts/docs_build/spelling_checks.py
@@ -21,10 +21,10 @@
 from functools import total_ordering
 from typing import NamedTuple
 
+from docs.exts.docs_build.code_utils import CONSOLE_WIDTH
 from rich.console import Console
 
 from airflow.utils.code_utils import prepare_code_snippet
-from docs.exts.docs_build.code_utils import CONSOLE_WIDTH
 
 CURRENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
 DOCS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir, os.pardir))
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 44411f5c075e8..ec07d609541e1 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -30,6 +30,7 @@
 from pathlib import Path
 from unittest import mock
 
+import attr
 import pendulum
 import pytest
 from dateutil.relativedelta import FR, relativedelta
@@ -42,6 +43,7 @@
 from airflow.kubernetes.pod_generator import PodGenerator
 from airflow.models import DAG, Connection, DagBag, Operator
 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.param import Param, ParamsDict
 from airflow.models.xcom import XCOM_RETURN_KEY, XCom
@@ -534,32 +536,47 @@ def validate_deserialized_task(
     serialized_task,
     task,
 ):
-    """Verify non-airflow operators are casted to BaseOperator."""
-    assert isinstance(serialized_task, SerializedBaseOperator)
+    """Verify non-Airflow operators are cast to BaseOperator or MappedOperator."""
     assert not isinstance(task, SerializedBaseOperator)
-    assert isinstance(task, BaseOperator)
+    assert isinstance(task, (BaseOperator, MappedOperator))
 
     # Every task should have a task_group property -- even if it's the DAG's root task group
     assert serialized_task.task_group
 
-    fields_to_check = task.get_serialized_fields() - {
-        # Checked separately
-        "_task_type",
-        "_operator_name",
-        "subdag",
-        # Type is excluded, so don't check it
-        "_log",
-        # List vs tuple. Check separately
-        "template_ext",
-        "template_fields",
-        # We store the string, real dag has the actual code
-        "on_failure_callback",
-        "on_success_callback",
-        "on_retry_callback",
-        # Checked separately
-        "resources",
-        "params",
-    }
+    if isinstance(task, BaseOperator):
+        assert isinstance(serialized_task, SerializedBaseOperator)
+        fields_to_check = task.get_serialized_fields() - {
+            # Checked separately
+            "_task_type",
+            "_operator_name",
+            "subdag",
+            # Type is excluded, so don't check it
+            "_log",
+            # List vs tuple. Check separately
+            "template_ext",
+            "template_fields",
+            # We store the string, real dag has the actual code
+            "on_failure_callback",
+            "on_success_callback",
+            "on_retry_callback",
+            # Checked separately
+            "resources",
+        }
+    else:  # Promised to be mapped by the assert above.
+        assert isinstance(serialized_task, MappedOperator)
+        fields_to_check = {f.name for f in attr.fields(MappedOperator)}
+        fields_to_check -= {
+            # Matching logic in BaseOperator.get_serialized_fields().
+            "dag",
+            "task_group",
+            # List vs tuple. Check separately.
+            "operator_extra_links",
+            "template_ext",
+            "template_fields",
+            # Checked separately.
+            "operator_class",
+            "partial_kwargs",
+        }
 
     assert serialized_task.task_type == task.task_type
@@ -580,9 +597,25 @@ def validate_deserialized_task(
     assert serialized_task.resources == task.resources
 
     # Ugly hack as some operators override params var in their init
-    if isinstance(task.params, ParamsDict):
+    if isinstance(task.params, ParamsDict) and isinstance(serialized_task.params, ParamsDict):
         assert serialized_task.params.dump() == task.params.dump()
 
+    if isinstance(task, MappedOperator):
+        # MappedOperator.operator_class holds a backup of the serialized
+        # data; checking its entirety basically duplicates this validation
+        # function, so we just do some sanity checks.
+        assert serialized_task.operator_class["_task_type"] == type(task).__name__
+        assert serialized_task.operator_class["_operator_name"] == task._operator_name
+
+        # Serialization cleans up default values in partial_kwargs, so this
+        # adds them back to both sides before comparing.
+        default_partial_kwargs = (
+            BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
+        )
+        serialized_partial_kwargs = {**default_partial_kwargs, **serialized_task.partial_kwargs}
+        original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs}
+        assert serialized_partial_kwargs == original_partial_kwargs
+
     # Check that for Deserialized task, task.subdag is None for all other Operators
     # except for the SubDagOperator where task.subdag is an instance of DAG object
     if task.task_type == "SubDagOperator":
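
The example DAG added above covers the simplest form of dynamic task mapping, where ``expand()`` maps a task over a literal list. When some arguments should stay constant across every mapped task instance, ``partial()`` pins them before ``expand()`` maps the rest. Below is a minimal sketch of that pattern; the DAG id and values are hypothetical and not part of this patch:

.. code-block:: python

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.decorators import task

    # Hypothetical DAG illustrating partial() + expand(); not part of this patch.
    with DAG(dag_id="example_partial_and_expand", start_date=datetime(2022, 3, 4)) as dag:

        @task
        def add(x: int, y: int):
            return x + y

        @task
        def sum_it(values):
            print(f"Total was {sum(values)}")

        # partial() fixes y=10 for every mapped instance; expand() creates
        # one task instance per element of x at run time.
        added = add.partial(y=10).expand(x=[1, 2, 3])
        sum_it(added)

With ``x=[1, 2, 3]`` and ``y=10`` this logs ``Total was 36``.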
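
The test changes extend ``validate_deserialized_task`` to handle mapped tasks, which deserialize to ``MappedOperator`` rather than the original operator class. The round trip being exercised looks roughly like the sketch below, using the public ``SerializedDAG`` helpers; it assumes the example DAGs are loadable in your environment:

.. code-block:: python

    from airflow.models.dagbag import DagBag
    from airflow.serialization.serialized_objects import SerializedDAG

    # Round-trip the example DAG through the serializer, the same path the
    # scheduler and webserver use to exchange DAGs.
    dag = DagBag().get_dag("example_dynamic_task_mapping")
    serialized = SerializedDAG.to_dict(dag)
    deserialized = SerializedDAG.from_dict(serialized)

    # The mapped task comes back as a MappedOperator, which is why the
    # validation helper now branches on the task type.
    print(type(deserialized.get_task("add_one")).__name__)  # MappedOperator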