Skip to content

Commit

Permalink
Adding an example dag for dynamic task mapping (apache#28325)
Browse files Browse the repository at this point in the history
  • Loading branch information
amoghrajesh authored Dec 15, 2022
1 parent 3fa20e0 commit b263dbc
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 47 deletions.
38 changes: 38 additions & 0 deletions airflow/example_dags/example_dynamic_task_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example DAG demonstrating the usage of dynamic task mapping."""
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.decorators import task

with DAG(dag_id="example_dynamic_task_mapping", start_date=datetime(2022, 3, 4)) as dag:

    @task
    def add_one(x: int):
        """Map task: increment a single element of the expanded input."""
        return x + 1

    @task
    def sum_it(values):
        """Reduce task: add up all mapped results and log the total."""
        total = sum(values)
        print(f"Total was {total}")

    # ``expand`` spawns one add_one task instance per list element at run
    # time; the collected results are then passed to a single sum_it task.
    mapped_results = add_one.expand(x=[1, 2, 3])
    sum_it(mapped_results)
1 change: 0 additions & 1 deletion docker_tests/test_docker_compose_quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from unittest import mock

import requests

from docker_tests.command_utils import run_command
from docker_tests.constants import SOURCE_ROOT
from docker_tests.docker_tests_utils import docker_image
Expand Down
23 changes: 3 additions & 20 deletions docs/apache-airflow/concepts/dynamic-task-mapping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,10 @@ Simple mapping

In its simplest form you can map over a list defined directly in your DAG file using the ``expand()`` function instead of calling your task directly.

.. code-block:: python
from datetime import datetime
from airflow import DAG
from airflow.decorators import task
with DAG(dag_id="simple_mapping", start_date=datetime(2022, 3, 4)) as dag:
@task
def add_one(x: int):
return x + 1
@task
def sum_it(values):
total = sum(values)
print(f"Total was {total}")
The following example shows a simple use of dynamic task mapping:

added_values = add_one.expand(x=[1, 2, 3])
sum_it(added_values)
.. exampleinclude:: /../../airflow/example_dags/example_dynamic_task_mapping.py
:language: python

This will show ``Total was 9`` in the task logs when executed.

Expand Down
5 changes: 2 additions & 3 deletions docs/build_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
from itertools import filterfalse, tee
from typing import Callable, Iterable, NamedTuple, TypeVar

from rich.console import Console
from tabulate import tabulate

from docs.exts.docs_build import dev_index_generator, lint_checks
from docs.exts.docs_build.code_utils import CONSOLE_WIDTH, PROVIDER_INIT_FILE
from docs.exts.docs_build.docs_builder import DOCS_DIR, AirflowDocsBuilder, get_available_packages
Expand All @@ -36,6 +33,8 @@
from docs.exts.docs_build.github_action_utils import with_group
from docs.exts.docs_build.package_filter import process_package_filters
from docs.exts.docs_build.spelling_checks import SpellingError, display_spelling_error_summary
from rich.console import Console
from tabulate import tabulate

TEXT_RED = "\033[31m"
TEXT_RESET = "\033[0m"
Expand Down
2 changes: 1 addition & 1 deletion docs/exts/docs_build/spelling_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from functools import total_ordering
from typing import NamedTuple

from docs.exts.docs_build.code_utils import CONSOLE_WIDTH
from rich.console import Console

from airflow.utils.code_utils import prepare_code_snippet
from docs.exts.docs_build.code_utils import CONSOLE_WIDTH

CURRENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
DOCS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir, os.pardir))
Expand Down
77 changes: 55 additions & 22 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pathlib import Path
from unittest import mock

import attr
import pendulum
import pytest
from dateutil.relativedelta import FR, relativedelta
Expand All @@ -42,6 +43,7 @@
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.models import DAG, Connection, DagBag, Operator
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.expandinput import EXPAND_INPUT_EMPTY
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
from airflow.models.xcom import XCOM_RETURN_KEY, XCom
Expand Down Expand Up @@ -534,32 +536,47 @@ def validate_deserialized_task(
serialized_task,
task,
):
"""Verify non-airflow operators are casted to BaseOperator."""
assert isinstance(serialized_task, SerializedBaseOperator)
"""Verify non-Airflow operators are casted to BaseOperator or MappedOperator."""
assert not isinstance(task, SerializedBaseOperator)
assert isinstance(task, BaseOperator)
assert isinstance(task, (BaseOperator, MappedOperator))

# Every task should have a task_group property -- even if it's the DAG's root task group
assert serialized_task.task_group

fields_to_check = task.get_serialized_fields() - {
# Checked separately
"_task_type",
"_operator_name",
"subdag",
# Type is excluded, so don't check it
"_log",
# List vs tuple. Check separately
"template_ext",
"template_fields",
# We store the string, real dag has the actual code
"on_failure_callback",
"on_success_callback",
"on_retry_callback",
# Checked separately
"resources",
"params",
}
if isinstance(task, BaseOperator):
assert isinstance(serialized_task, SerializedBaseOperator)
fields_to_check = task.get_serialized_fields() - {
# Checked separately
"_task_type",
"_operator_name",
"subdag",
# Type is excluded, so don't check it
"_log",
# List vs tuple. Check separately
"template_ext",
"template_fields",
# We store the string, real dag has the actual code
"on_failure_callback",
"on_success_callback",
"on_retry_callback",
# Checked separately
"resources",
}
else: # Promised to be mapped by the assert above.
assert isinstance(serialized_task, MappedOperator)
fields_to_check = {f.name for f in attr.fields(MappedOperator)}
fields_to_check -= {
# Matching logic in BaseOperator.get_serialized_fields().
"dag",
"task_group",
# List vs tuple. Check separately.
"operator_extra_links",
"template_ext",
"template_fields",
# Checked separately.
"operator_class",
"partial_kwargs",
}

assert serialized_task.task_type == task.task_type

Expand All @@ -580,9 +597,25 @@ def validate_deserialized_task(
assert serialized_task.resources == task.resources

# Ugly hack as some operators override params var in their init
if isinstance(task.params, ParamsDict):
if isinstance(task.params, ParamsDict) and isinstance(serialized_task.params, ParamsDict):
assert serialized_task.params.dump() == task.params.dump()

if isinstance(task, MappedOperator):
# MappedOperator.operator_class holds a backup of the serialized
# data; checking its entirety basically duplicates this validation
# function, so we just do some sanity checks.
serialized_task.operator_class["_task_type"] == type(task).__name__
serialized_task.operator_class["_operator_name"] == task._operator_name

# Serialization cleans up default values in partial_kwargs, this
# adds them back to both sides.
default_partial_kwargs = (
BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
)
serialized_partial_kwargs = {**default_partial_kwargs, **serialized_task.partial_kwargs}
original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs}
assert serialized_partial_kwargs == original_partial_kwargs

# Check that for Deserialized task, task.subdag is None for all other Operators
# except for the SubDagOperator where task.subdag is an instance of DAG object
if task.task_type == "SubDagOperator":
Expand Down

0 comments on commit b263dbc

Please sign in to comment.