Skip to content

Commit

Permalink
openlineage: use airflow provided getters from conf (apache#40790)
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Muda <[email protected]>
  • Loading branch information
kacpermuda authored Jul 16, 2024
1 parent 519b0d0 commit 985ccbc
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 91 deletions.
46 changes: 15 additions & 31 deletions airflow/providers/openlineage/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,9 @@
"""
This module provides functions for safely retrieving and handling OpenLineage configurations.
To prevent errors caused by invalid user-provided configuration values, we use ``conf.get()``
to fetch values as strings and perform safe conversions using custom functions.
Any invalid configuration values should be treated as incorrect and replaced with default values.
For example, if the default for boolean ``custom_ol_var`` is False, any non-true value provided:
``"asdf"``, ``12345``, ``{"key": 1}`` or empty string, will result in False being used.
By using default values for invalid configuration values, we ensure that the configurations are handled
safely, preventing potential runtime errors due to conversion issues.
For the legacy boolean env variables `OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE` and `OPENLINEAGE_DISABLED`,
any string not equal to "true", "1", or "t" should be treated as False, to maintain backward compatibility.
Support for legacy variables will be removed in Airflow 3.
"""

from __future__ import annotations
Expand All @@ -51,13 +45,6 @@ def _is_true(arg: Any) -> bool:
return str(arg).lower().strip() in ("true", "1", "t")


def _safe_int_convert(arg: Any, default: int) -> int:
try:
return int(arg)
except (ValueError, TypeError):
return default


@cache
def config_path(check_legacy_env_var: bool = True) -> str:
"""[openlineage] config_path."""
Expand All @@ -70,11 +57,11 @@ def config_path(check_legacy_env_var: bool = True) -> str:
@cache
def is_source_enabled() -> bool:
"""[openlineage] disable_source_code."""
option = conf.get(_CONFIG_SECTION, "disable_source_code", fallback="")
if not option:
option = os.getenv("OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE", "")
# when disable_source_code is True, is_source_enabled() should be False
return not _is_true(option)
option = conf.getboolean(_CONFIG_SECTION, "disable_source_code", fallback="False")
if option is False: # Check legacy variable
option = _is_true(os.getenv("OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE", ""))
# when disable_source_code is True, is_source_enabled() should be False; hence the "not"
return not option


@cache
Expand All @@ -87,8 +74,7 @@ def disabled_operators() -> set[str]:
@cache
def selective_enable() -> bool:
"""[openlineage] selective_enable."""
option = conf.get(_CONFIG_SECTION, "selective_enable", fallback="")
return _is_true(option)
return conf.getboolean(_CONFIG_SECTION, "selective_enable", fallback="False")


@cache
Expand Down Expand Up @@ -121,13 +107,12 @@ def transport() -> dict[str, Any]:
@cache
def is_disabled() -> bool:
"""[openlineage] disabled + check if any configuration is present."""
option = conf.get(_CONFIG_SECTION, "disabled", fallback="")
if _is_true(option):
if conf.getboolean(_CONFIG_SECTION, "disabled", fallback="False"):
return True

option = os.getenv("OPENLINEAGE_DISABLED", "")
if _is_true(option):
if _is_true(os.getenv("OPENLINEAGE_DISABLED", "")): # Check legacy variable
return True

# Check if both 'transport' and 'config_path' are not present and also
# if legacy 'OPENLINEAGE_URL' environment variables is not set
return transport() == {} and config_path(True) == "" and os.getenv("OPENLINEAGE_URL", "") == ""
Expand All @@ -136,17 +121,16 @@ def is_disabled() -> bool:
@cache
def dag_state_change_process_pool_size() -> int:
"""[openlineage] dag_state_change_process_pool_size."""
option = conf.get(_CONFIG_SECTION, "dag_state_change_process_pool_size", fallback="")
return _safe_int_convert(str(option).strip(), default=1)
return conf.getint(_CONFIG_SECTION, "dag_state_change_process_pool_size", fallback="1")


@cache
def execution_timeout() -> int:
"""[openlineage] execution_timeout."""
option = conf.get(_CONFIG_SECTION, "execution_timeout", fallback="")
return _safe_int_convert(str(option).strip(), default=10)
return conf.getint(_CONFIG_SECTION, "execution_timeout", fallback="10")


@cache
def include_full_task_info() -> bool:
"""[openlineage] include_full_task_info."""
return conf.getboolean(_CONFIG_SECTION, "include_full_task_info", fallback="False")
151 changes: 91 additions & 60 deletions tests/providers/openlineage/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from airflow.exceptions import AirflowConfigException
from airflow.providers.openlineage.conf import (
_is_true,
_safe_int_convert,
config_path,
custom_extractors,
dag_state_change_process_pool_size,
disabled_operators,
execution_timeout,
include_full_task_info,
is_disabled,
is_source_enabled,
Expand All @@ -44,6 +44,7 @@
_VAR_DISABLE_SOURCE_CODE = "OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE"
_CONFIG_OPTION_DISABLE_SOURCE_CODE = "disable_source_code"
_CONFIG_OPTION_DISABLED_FOR_OPERATORS = "disabled_for_operators"
_CONFIG_OPTION_EXECUTION_TIMEOUT = "execution_timeout"
_VAR_EXTRACTORS = "OPENLINEAGE_EXTRACTORS"
_CONFIG_OPTION_EXTRACTORS = "extractors"
_VAR_NAMESPACE = "OPENLINEAGE_NAMESPACE"
Expand Down Expand Up @@ -86,35 +87,6 @@ def test_is_true(var_string, expected):
assert _is_true(var_string) is expected


@pytest.mark.parametrize(
"input_value, expected",
[
("123", 123),
(456, 456),
("789", 789),
(0, 0),
("0", 0),
],
)
def test_safe_int_convert(input_value, expected):
assert _safe_int_convert(input_value, default=1) == expected


@pytest.mark.parametrize(
"input_value, default",
[
("abc", 1),
("", 2),
(None, 3),
("123abc", 4),
([], 5),
("1.2", 6),
],
)
def test_safe_int_convert_erroneous_values(input_value, default):
assert _safe_int_convert(input_value, default) == default


@env_vars({_VAR_CONFIG_PATH: "env_var_path"})
@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): None})
def test_config_path_legacy_env_var_is_used_when_no_conf_option_set():
Expand Down Expand Up @@ -161,7 +133,8 @@ def test_disable_source_code_conf_option_has_precedence_over_legacy_env_var():

@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLE_SOURCE_CODE): "asdadawlaksnd"})
def test_disable_source_code_conf_option_not_working_for_random_string():
assert is_source_enabled() is True
with pytest.raises(AirflowConfigException):
is_source_enabled()


@env_vars({_VAR_DISABLE_SOURCE_CODE: "asdadawlaksnd"})
Expand All @@ -172,7 +145,8 @@ def test_disable_source_code_legacy_env_var_not_working_for_random_string():

@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLE_SOURCE_CODE): ""})
def test_disable_source_code_empty_conf_option():
assert is_source_enabled() is True
with pytest.raises(AirflowConfigException):
is_source_enabled()


@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLE_SOURCE_CODE): None})
Expand All @@ -192,12 +166,14 @@ def test_selective_enable(var_string, expected):

@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_SELECTIVE_ENABLE): "asdadawlaksnd"})
def test_selective_enable_not_working_for_random_string():
assert selective_enable() is False
with pytest.raises(AirflowConfigException):
selective_enable()


@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_SELECTIVE_ENABLE): ""})
def test_selective_enable_empty_conf_option():
assert selective_enable() is False
with pytest.raises(AirflowConfigException):
selective_enable()


@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_SELECTIVE_ENABLE): None})
Expand Down Expand Up @@ -346,8 +322,9 @@ def test_is_disabled_possible_values_for_disabling(disabled):
(_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "asdadawlaksnd",
}
)
def test_is_disabled_is_not_disabled_by_random_string():
assert is_disabled() is False
def test_is_disabled_raises_for_random_string():
with pytest.raises(AirflowConfigException):
is_disabled()


@mock.patch.dict(os.environ, {_VAR_URL: "https://test.com"}, clear=True)
Expand All @@ -358,6 +335,18 @@ def test_is_disabled_is_not_disabled_by_random_string():
(_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "",
}
)
def test_is_disabled_raises_error_for_empty_string():
with pytest.raises(AirflowConfigException):
is_disabled()


@mock.patch.dict(os.environ, {_VAR_URL: "https://test.com"}, clear=True)
@conf_vars(
{
(_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "",
(_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): "",
}
)
def test_is_disabled_is_false_when_not_explicitly_disabled_and_url_set():
assert is_disabled() is False

Expand All @@ -367,7 +356,6 @@ def test_is_disabled_is_false_when_not_explicitly_disabled_and_url_set():
{
(_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "",
(_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): '{"valid": "transport"}',
(_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "",
}
)
def test_is_disabled_is_false_when_not_explicitly_disabled_and_transport_set():
Expand All @@ -379,7 +367,6 @@ def test_is_disabled_is_false_when_not_explicitly_disabled_and_transport_set():
{
(_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "some/path.yml",
(_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): "",
(_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "",
}
)
def test_is_disabled_is_false_when_not_explicitly_disabled_and_config_path_set():
Expand All @@ -403,7 +390,6 @@ def test_is_disabled_conf_option_is_enough_to_disable():
{
(_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "some/path.yml",
(_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): '{"valid": "transport"}',
(_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "",
}
)
def test_is_disabled_legacy_env_var_is_enough_to_disable():
Expand Down Expand Up @@ -451,7 +437,6 @@ def test_is_disabled_env_var_true_has_precedence_over_conf_false():
{
(_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "",
(_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): "",
(_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "",
}
)
def test_is_disabled_empty_conf_option():
Expand All @@ -476,13 +461,6 @@ def test_is_disabled_do_not_fail_if_conf_option_missing():
("1", 1),
("2 ", 2),
(" 3", 3),
("4.56", 1), # default
("asdf", 1), # default
("true", 1), # default
("false", 1), # default
("None", 1), # default
("", 1), # default
(" ", 1), # default
(None, 1), # default
),
)
Expand All @@ -493,30 +471,83 @@ def test_dag_state_change_process_pool_size(var_string, expected):


@pytest.mark.parametrize(
("var", "expected"),
"var_string",
(
("False", False),
("True", True),
("t", True),
("true", True),
"4.56",
"asdf",
"true",
"false",
"None",
"",
" ",
),
)
def test_include_full_task_info_reads_config(var, expected):
with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_INCLUDE_FULL_TASK_INFO): var}):
assert include_full_task_info() is expected
def test_dag_state_change_process_pool_size_invalid_value_raise_error(var_string):
with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DAG_STATE_CHANGE_PROCESS_POOL_SIZE): var_string}):
with pytest.raises(AirflowConfigException):
dag_state_change_process_pool_size()


@pytest.mark.parametrize(
"var",
[
("var_string", "expected"),
(
("1", 1),
("2 ", 2),
(" 3", 3),
(None, 10), # default
),
)
def test_execution_timeout(var_string, expected):
with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_EXECUTION_TIMEOUT): var_string}):
result = execution_timeout()
assert result == expected


@pytest.mark.parametrize(
"var_string",
(
"4.56",
"asdf",
"true",
"false",
"None",
"",
" ",
),
)
def test_execution_timeout_invalid_value_raise_error(var_string):
with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_EXECUTION_TIMEOUT): var_string}):
with pytest.raises(AirflowConfigException):
execution_timeout()


@pytest.mark.parametrize(
("var_string", "expected"),
_BOOL_PARAMS,
)
def test_include_full_task_info(var_string, expected):
with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_INCLUDE_FULL_TASK_INFO): var_string}):
result = include_full_task_info()
assert result is expected


@pytest.mark.parametrize(
"var_string",
(
"a",
"asdf",
"None",
"31",
"",
" ",
],
),
)
def test_include_full_task_info_raises_exception(var):
with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_INCLUDE_FULL_TASK_INFO): var}):
def test_include_full_task_info_invalid_value_raise_error(var_string):
with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_INCLUDE_FULL_TASK_INFO): var_string}):
with pytest.raises(AirflowConfigException):
include_full_task_info()


@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_INCLUDE_FULL_TASK_INFO): None})
def test_include_full_task_info_do_not_fail_if_conf_option_missing():
assert include_full_task_info() is False

0 comments on commit 985ccbc

Please sign in to comment.