Skip to content

Commit

Permalink
Resolve aws provider deprecations in tests (apache#40026)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirrao authored Jun 3, 2024
1 parent 19c145c commit 4849fef
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 24 deletions.
6 changes: 0 additions & 6 deletions tests/deprecations_ignore.yml
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,6 @@
- tests/providers/amazon/aws/deferrable/hooks/test_redshift_cluster.py::TestRedshiftAsyncHook::test_pause_cluster
- tests/providers/amazon/aws/deferrable/hooks/test_redshift_cluster.py::TestRedshiftAsyncHook::test_resume_cluster
- tests/providers/amazon/aws/deferrable/hooks/test_redshift_cluster.py::TestRedshiftAsyncHook::test_resume_cluster_exception
- tests/providers/amazon/aws/hooks/test_sagemaker.py::TestSageMakerHook::test_start_pipeline_waits_for_completion
- tests/providers/amazon/aws/hooks/test_sagemaker.py::TestSageMakerHook::test_stop_pipeline_waits_for_completion
- tests/providers/amazon/aws/hooks/test_sagemaker.py::TestSageMakerHook::test_stop_pipeline_waits_for_completion_even_when_already_stopped
- tests/providers/amazon/aws/operators/test_appflow.py::test_base_aws_op_attributes
- tests/providers/amazon/aws/operators/test_appflow.py::test_run
- tests/providers/amazon/aws/operators/test_base_aws.py::TestAwsBaseOperator::test_conflicting_region_name
Expand Down Expand Up @@ -235,12 +232,9 @@
- tests/providers/amazon/aws/secrets/test_secrets_manager.py::TestSecretsManagerBackend::test_get_connection_broken_field_mode_extra_allows_nested_json
- tests/providers/amazon/aws/secrets/test_secrets_manager.py::TestSecretsManagerBackend::test_get_connection_broken_field_mode_url_encoding
- tests/providers/amazon/aws/sensors/test_base_aws.py::TestAwsBaseSensor::test_conflicting_region_name
- tests/providers/amazon/aws/sensors/test_ecs.py::TestEcsBaseSensor::test_hook_and_client
- tests/providers/amazon/aws/sensors/test_ecs.py::TestEcsBaseSensor::test_initialise_operator
- tests/providers/amazon/aws/triggers/test_redshift_cluster.py::TestRedshiftClusterTrigger::test_redshift_cluster_sensor_trigger_exception
- tests/providers/amazon/aws/triggers/test_redshift_cluster.py::TestRedshiftClusterTrigger::test_redshift_cluster_sensor_trigger_resuming_status
- tests/providers/amazon/aws/triggers/test_redshift_cluster.py::TestRedshiftClusterTrigger::test_redshift_cluster_sensor_trigger_success
- tests/providers/amazon/aws/utils/test_connection_wrapper.py::TestAwsConnectionWrapper::test_get_endpoint_url_from_extra
- tests/providers/apache/spark/operators/test_spark_sql.py::TestSparkSqlOperator::test_execute
- tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_instance_check_works_for_legacy_db_api_hook
- tests/providers/common/sql/operators/test_sql.py::TestSQLCheckOperatorDbHook::test_get_hook
Expand Down
28 changes: 20 additions & 8 deletions tests/providers/amazon/aws/hooks/test_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from dateutil.tz import tzlocal
from moto import mock_aws

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.hooks.sagemaker import (
Expand Down Expand Up @@ -737,7 +737,11 @@ def test_start_pipeline_waits_for_completion(self, mock_conn):
]

hook = SageMakerHook(aws_conn_id="aws_default")
hook.start_pipeline(pipeline_name="test_name", wait_for_completion=True, check_interval=0)
with pytest.warns(
AirflowProviderDeprecationWarning,
match="parameter `wait_for_completion` and `check_interval` are deprecated, remove them and call check_status yourself if you want to wait for completion",
):
hook.start_pipeline(pipeline_name="test_name", wait_for_completion=True, check_interval=0)

assert mock_conn().describe_pipeline_execution.call_count == 3

Expand All @@ -760,9 +764,13 @@ def test_stop_pipeline_waits_for_completion(self, mock_conn):
]

hook = SageMakerHook(aws_conn_id="aws_default")
pipeline_status = hook.stop_pipeline(
pipeline_exec_arn="test", wait_for_completion=True, check_interval=0
)
with pytest.warns(
AirflowProviderDeprecationWarning,
match="parameter `wait_for_completion` and `check_interval` are deprecated, remove them and call check_status yourself if you want to wait for completion",
):
pipeline_status = hook.stop_pipeline(
pipeline_exec_arn="test", wait_for_completion=True, check_interval=0
)

assert pipeline_status == "Stopped"
assert mock_conn().describe_pipeline_execution.call_count == 3
Expand All @@ -781,9 +789,13 @@ def test_stop_pipeline_waits_for_completion_even_when_already_stopped(self, mock
]

hook = SageMakerHook(aws_conn_id="aws_default")
pipeline_status = hook.stop_pipeline(
pipeline_exec_arn="test", wait_for_completion=True, check_interval=0
)
with pytest.warns(
AirflowProviderDeprecationWarning,
match="parameter `wait_for_completion` and `check_interval` are deprecated, remove them and call check_status yourself if you want to wait for completion",
):
pipeline_status = hook.stop_pipeline(
pipeline_exec_arn="test", wait_for_completion=True, check_interval=0
)

assert pipeline_status == "Stopped"

Expand Down
6 changes: 3 additions & 3 deletions tests/providers/amazon/aws/sensors/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,18 @@ class TestEcsBaseSensor(EcsBaseTestCase):
@pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
def test_initialise_operator(self, aws_conn_id, region_name):
"""Test sensor initialize."""
op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
op_kw = {"aws_conn_id": aws_conn_id, "region_name": region_name}
op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
op = EcsBaseSensor(task_id="test_ecs_base", **op_kw)

assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else "aws_default")
assert op.region == (region_name if region_name is not NOTSET else None)
assert op.region_name == (region_name if region_name is not NOTSET else None)

@pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
@pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
def test_hook_and_client(self, aws_conn_id, region_name):
"""Test initialize ``EcsHook`` and ``boto3.client``."""
op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
op_kw = {"aws_conn_id": aws_conn_id, "region_name": region_name}
op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
op = EcsBaseSensor(task_id="test_ecs_base_hook_client", **op_kw)

Expand Down
13 changes: 6 additions & 7 deletions tests/providers/amazon/aws/utils/test_connection_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,16 +352,15 @@ def test_get_botocore_config(self, mock_botocore_config, botocore_config, botoco
def test_get_endpoint_url_from_extra(self, extra, expected):
mock_conn = mock_connection_factory(extra=extra)
expected_deprecation_message = (
"extra['host'] is deprecated and will be removed in a future release."
" Please set extra['endpoint_url'] instead"
r"extra\['host'\] is deprecated and will be removed in a future release."
r" Please set extra\['endpoint_url'\] instead"
)

with warnings.catch_warnings(record=True) as records:
wrap_conn = AwsConnectionWrapper(conn=mock_conn)

if extra.get("host"):
assert len(records) == 1
assert str(records[0].message) == expected_deprecation_message
with pytest.warns(AirflowProviderDeprecationWarning, match=expected_deprecation_message):
wrap_conn = AwsConnectionWrapper(conn=mock_conn)
else:
wrap_conn = AwsConnectionWrapper(conn=mock_conn)

assert wrap_conn.endpoint_url == expected

Expand Down

0 comments on commit 4849fef

Please sign in to comment.