Skip to content

Commit

Permalink
Improve validation of Group id (apache#17578)
Browse files Browse the repository at this point in the history
When Group id of task group is used to prefix task id, it should
follow the same limitation that task_id has, plus it should not
have '.'. The '.' is used to separate groups in task id
so it should not be allowed in the group id.

If this is not checked at Task Group creation time, users will
get messages about invalid task id during deserialization
and it's not entirely obvoius where the error came from
and it crashes the scheduler..

Also this validation will be performed at parsing time, rather
than at deserialization time and the DAG will not even get
serialized, so it will not crash the scheduler.

Fixes: apache#17568
  • Loading branch information
potiuk authored Aug 13, 2021
1 parent 7db43f7 commit 833e109
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 12 deletions.
23 changes: 17 additions & 6 deletions airflow/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,36 @@
from airflow.models import TaskInstance

KEY_REGEX = re.compile(r'^[\w.-]+$')
GROUP_KEY_REGEX = re.compile(r'^[\w-]+$')
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r'(?!^)([A-Z]+)')

T = TypeVar('T')
S = TypeVar('S')


def validate_key(k: str, max_length: int = 250) -> bool:
def validate_key(k: str, max_length: int = 250):
"""Validates value used as a key."""
if not isinstance(k, str):
raise TypeError("The key has to be a string")
elif len(k) > max_length:
raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
if len(k) > max_length:
raise AirflowException(f"The key has to be less than {max_length} characters")
elif not KEY_REGEX.match(k):
if not KEY_REGEX.match(k):
raise AirflowException(
"The key ({k}) has to be made of alphanumeric characters, dashes, "
"dots and underscores exclusively".format(k=k)
)
else:
return True


def validate_group_key(k: str, max_length: int = 200):
"""Validates value used as a group key."""
if not isinstance(k, str):
raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
if len(k) > max_length:
raise AirflowException(f"The key has to be less than {max_length} characters")
if not GROUP_KEY_REGEX.match(k):
raise AirflowException(
f"The key ({k}) has to be made of alphanumeric characters, dashes " "and underscores exclusively"
)


def alchemy_to_dict(obj: Any) -> Optional[Dict]:
Expand Down
14 changes: 10 additions & 4 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from airflow.exceptions import AirflowException, DuplicateTaskIdFound
from airflow.models.taskmixin import TaskMixin
from airflow.utils.helpers import validate_group_key

if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
Expand Down Expand Up @@ -94,10 +95,15 @@ def __init__(
self.used_group_ids: Set[Optional[str]] = set()
self._parent_group = None
else:
if not isinstance(group_id, str):
raise ValueError("group_id must be str")
if not group_id:
raise ValueError("group_id must not be empty")
if prefix_group_id:
# If group id is used as prefix, it should not contain spaces nor dots
# because it is used as prefix in the task_id
validate_group_key(group_id)
else:
if not isinstance(group_id, str):
raise ValueError("group_id must be str")
if not group_id:
raise ValueError("group_id must not be empty")

dag = dag or DagContext.get_current_dag()

Expand Down
76 changes: 74 additions & 2 deletions tests/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re
import unittest
from datetime import datetime

import pytest
from parameterized import parameterized

from airflow import AirflowException
from airflow.models import TaskInstance
from airflow.models.dag import DAG
from airflow.operators.dummy import DummyOperator
from airflow.utils import helpers
from airflow.utils.helpers import build_airflow_url_with_query, merge_dicts
from airflow.utils.helpers import build_airflow_url_with_query, merge_dicts, validate_group_key, validate_key
from tests.test_utils.config import conf_vars


Expand Down Expand Up @@ -154,3 +156,73 @@ def test_build_airflow_url_with_query(self):

with cached_app(testing=True).test_request_context():
assert build_airflow_url_with_query(query) == expected_url

@parameterized.expand(
[
(3, "The key has to be a string and is <class 'int'>:3", TypeError),
(None, "The key has to be a string and is <class 'NoneType'>:None", TypeError),
("simple_key", None, None),
("simple-key", None, None),
("group.simple_key", None, None),
("root.group.simple-key", None, None),
(
"key with space",
"The key (key with space) has to be made of alphanumeric "
"characters, dashes, dots and underscores exclusively",
AirflowException,
),
(
"key_with_!",
"The key (key_with_!) has to be made of alphanumeric "
"characters, dashes, dots and underscores exclusively",
AirflowException,
),
(' ' * 251, "The key has to be less than 250 characters", AirflowException),
]
)
def test_validate_key(self, key_id, message, exception):
if message:
with pytest.raises(exception, match=re.escape(message)):
validate_key(key_id)
else:
validate_key(key_id)

@parameterized.expand(
[
(3, "The key has to be a string and is <class 'int'>:3", TypeError),
(None, "The key has to be a string and is <class 'NoneType'>:None", TypeError),
("simple_key", None, None),
("simple-key", None, None),
(
"group.simple_key",
"The key (group.simple_key) has to be made of alphanumeric "
"characters, dashes and underscores exclusively",
AirflowException,
),
(
"root.group-name.simple_key",
"The key (root.group-name.simple_key) has to be made of alphanumeric "
"characters, dashes and underscores exclusively",
AirflowException,
),
(
"key with space",
"The key (key with space) has to be made of alphanumeric "
"characters, dashes and underscores exclusively",
AirflowException,
),
(
"key_with_!",
"The key (key_with_!) has to be made of alphanumeric "
"characters, dashes and underscores exclusively",
AirflowException,
),
(' ' * 201, "The key has to be less than 200 characters", AirflowException),
]
)
def test_validate_group_key(self, key_id, message, exception):
if message:
with pytest.raises(exception, match=re.escape(message)):
validate_group_key(key_id)
else:
validate_group_key(key_id)

0 comments on commit 833e109

Please sign in to comment.