Skip to content

Commit

Permalink
Migrate amazon provider sensor tests from unittests to pytest (ap…
Browse files Browse the repository at this point in the history
  • Loading branch information
IAL32 authored Dec 6, 2022
1 parent 5fec787 commit b726d8e
Show file tree
Hide file tree
Showing 22 changed files with 123 additions and 158 deletions.
11 changes: 5 additions & 6 deletions tests/providers/amazon/aws/sensors/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

import pytest
Expand All @@ -27,8 +26,8 @@
from airflow.providers.amazon.aws.sensors.athena import AthenaSensor


class TestAthenaSensor(unittest.TestCase):
def setUp(self):
class TestAthenaSensor:
def setup_method(self):
self.sensor = AthenaSensor(
task_id="test_athena_sensor",
query_execution_id="abc",
Expand All @@ -39,15 +38,15 @@ def setUp(self):

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("SUCCEEDED",))
def test_poke_success(self, mock_poll_query_status):
assert self.sensor.poke({})
assert self.sensor.poke({}) is True

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("RUNNING",))
def test_poke_running(self, mock_poll_query_status):
assert not self.sensor.poke({})
assert self.sensor.poke({}) is False

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("QUEUED",))
def test_poke_queued(self, mock_poll_query_status):
assert not self.sensor.poke({})
assert self.sensor.poke({}) is False

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("FAILED",))
def test_poke_failed(self, mock_poll_query_status):
Expand Down
49 changes: 20 additions & 29 deletions tests/providers/amazon/aws/sensors/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

import pytest
from parameterized import parameterized

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
Expand All @@ -34,8 +32,8 @@
JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"


class TestBatchSensor(unittest.TestCase):
def setUp(self):
class TestBatchSensor:
def setup_method(self):
self.batch_sensor = BatchSensor(
task_id="batch_job_sensor",
job_id=JOB_ID,
Expand All @@ -44,45 +42,38 @@ def setUp(self):
@mock.patch.object(BatchClientHook, "get_job_description")
def test_poke_on_success_state(self, mock_get_job_description):
mock_get_job_description.return_value = {"status": "SUCCEEDED"}
self.assertTrue(self.batch_sensor.poke({}))
assert self.batch_sensor.poke({})
mock_get_job_description.assert_called_once_with(JOB_ID)

@mock.patch.object(BatchClientHook, "get_job_description")
def test_poke_on_failure_state(self, mock_get_job_description):
mock_get_job_description.return_value = {"status": "FAILED"}
with self.assertRaises(AirflowException) as e:
with pytest.raises(AirflowException, match="Batch sensor failed. AWS Batch job status: FAILED"):
self.batch_sensor.poke({})

self.assertEqual("Batch sensor failed. AWS Batch job status: FAILED", str(e.exception))
mock_get_job_description.assert_called_once_with(JOB_ID)

@mock.patch.object(BatchClientHook, "get_job_description")
def test_poke_on_invalid_state(self, mock_get_job_description):
mock_get_job_description.return_value = {"status": "INVALID"}
with self.assertRaises(AirflowException) as e:
with pytest.raises(
AirflowException, match="Batch sensor failed. Unknown AWS Batch job status: INVALID"
):
self.batch_sensor.poke({})

self.assertEqual("Batch sensor failed. Unknown AWS Batch job status: INVALID", str(e.exception))
mock_get_job_description.assert_called_once_with(JOB_ID)

@parameterized.expand(
[
("SUBMITTED",),
("PENDING",),
("RUNNABLE",),
("STARTING",),
("RUNNING",),
]
)
@pytest.mark.parametrize("job_status", ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"])
@mock.patch.object(BatchClientHook, "get_job_description")
def test_poke_on_intermediate_state(self, job_status, mock_get_job_description):
def test_poke_on_intermediate_state(self, mock_get_job_description, job_status):
print(job_status)
mock_get_job_description.return_value = {"status": job_status}
self.assertFalse(self.batch_sensor.poke({}))
assert self.batch_sensor.poke({}) is False
mock_get_job_description.assert_called_once_with(JOB_ID)


class TestBatchComputeEnvironmentSensor(unittest.TestCase):
def setUp(self):
class TestBatchComputeEnvironmentSensor:
def setup_method(self):
self.environment_name = "environment_name"
self.sensor = BatchComputeEnvironmentSensor(
task_id="test_batch_compute_environment_sensor",
Expand All @@ -104,7 +95,7 @@ def test_poke_valid(self, mock_batch_client):
mock_batch_client.describe_compute_environments.return_value = {
"computeEnvironments": [{"status": "VALID"}]
}
assert self.sensor.poke({})
assert self.sensor.poke({}) is True
mock_batch_client.describe_compute_environments.assert_called_once_with(
computeEnvironments=[self.environment_name],
)
Expand All @@ -118,7 +109,7 @@ def test_poke_running(self, mock_batch_client):
}
]
}
assert not self.sensor.poke({})
assert self.sensor.poke({}) is False
mock_batch_client.describe_compute_environments.assert_called_once_with(
computeEnvironments=[self.environment_name],
)
Expand All @@ -140,8 +131,8 @@ def test_poke_invalid(self, mock_batch_client):
assert "AWS Batch compute environment failed" in str(ctx.value)


class TestBatchJobQueueSensor(unittest.TestCase):
def setUp(self):
class TestBatchJobQueueSensor:
def setup_method(self):
self.job_queue = "job_queue"
self.sensor = BatchJobQueueSensor(
task_id="test_batch_job_queue_sensor",
Expand All @@ -162,15 +153,15 @@ def test_poke_no_queue(self, mock_batch_client):
def test_poke_no_queue_with_treat_non_existing_as_deleted(self, mock_batch_client):
self.sensor.treat_non_existing_as_deleted = True
mock_batch_client.describe_job_queues.return_value = {"jobQueues": []}
assert self.sensor.poke({})
assert self.sensor.poke({}) is True
mock_batch_client.describe_job_queues.assert_called_once_with(
jobQueues=[self.job_queue],
)

@mock.patch.object(BatchClientHook, "client")
def test_poke_valid(self, mock_batch_client):
mock_batch_client.describe_job_queues.return_value = {"jobQueues": [{"status": "VALID"}]}
assert self.sensor.poke({})
assert self.sensor.poke({}) is True
mock_batch_client.describe_job_queues.assert_called_once_with(
jobQueues=[self.job_queue],
)
Expand All @@ -184,7 +175,7 @@ def test_poke_running(self, mock_batch_client):
}
]
}
assert not self.sensor.poke({})
assert self.sensor.poke({}) is False
mock_batch_client.describe_job_queues.assert_called_once_with(
jobQueues=[self.job_queue],
)
Expand Down
8 changes: 2 additions & 6 deletions tests/providers/amazon/aws/sensors/test_cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,10 @@ def test_poke_stack_in_unsuccessful_state(self):
self.cloudformation_client_mock.describe_stacks.return_value = {
"Stacks": [{"StackStatus": "bar"}]
}
with pytest.raises(ValueError) as ctx:
with pytest.raises(ValueError, match="Stack foo in bad state: bar"):
op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo")
op.poke({})

assert "Stack foo in bad state: bar" == str(ctx.value)


class TestCloudFormationDeleteStackSensor:
task_id = "test_cloudformation_cluster_delete_sensor"
Expand Down Expand Up @@ -105,12 +103,10 @@ def test_poke_stack_in_unsuccessful_state(self):
self.cloudformation_client_mock.describe_stacks.return_value = {
"Stacks": [{"StackStatus": "bar"}]
}
with pytest.raises(ValueError) as ctx:
with pytest.raises(ValueError, match="Stack foo in bad state: bar"):
op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo")
op.poke({})

assert "Stack foo in bad state: bar" == str(ctx.value)

@mock_cloudformation
def test_poke_stack_does_not_exist(self):
op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo")
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/amazon/aws/sensors/test_dms_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

import pytest
Expand All @@ -26,8 +25,8 @@
from airflow.providers.amazon.aws.sensors.dms import DmsTaskCompletedSensor


class TestDmsTaskCompletedSensor(unittest.TestCase):
def setUp(self):
class TestDmsTaskCompletedSensor:
def setup_method(self):
self.sensor = DmsTaskCompletedSensor(
task_id="test_dms_sensor",
aws_conn_id="aws_default",
Expand Down
12 changes: 6 additions & 6 deletions tests/providers/amazon/aws/sensors/test_eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ def setUp(self):

@mock.patch.object(EksHook, "get_cluster_state", return_value=ClusterStates.ACTIVE)
def test_poke_reached_target_state(self, mock_get_cluster_state, setUp):
assert self.sensor.poke({})
assert self.sensor.poke({}) is True
mock_get_cluster_state.assert_called_once_with(clusterName=CLUSTER_NAME)

@mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.get_cluster_state")
@pytest.mark.parametrize("pending_state", CLUSTER_PENDING_STATES)
def test_poke_reached_pending_state(self, mock_get_cluster_state, setUp, pending_state):
mock_get_cluster_state.return_value = pending_state

assert not self.sensor.poke({})
assert self.sensor.poke({}) is False
mock_get_cluster_state.assert_called_once_with(clusterName=CLUSTER_NAME)

@mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.get_cluster_state")
Expand Down Expand Up @@ -104,7 +104,7 @@ def setUp(self):

@mock.patch.object(EksHook, "get_fargate_profile_state", return_value=FargateProfileStates.ACTIVE)
def test_poke_reached_target_state(self, mock_get_fargate_profile_state, setUp):
assert self.sensor.poke({})
assert self.sensor.poke({}) is True
mock_get_fargate_profile_state.assert_called_once_with(
clusterName=CLUSTER_NAME, fargateProfileName=FARGATE_PROFILE_NAME
)
Expand All @@ -114,7 +114,7 @@ def test_poke_reached_target_state(self, mock_get_fargate_profile_state, setUp):
def test_poke_reached_pending_state(self, mock_get_fargate_profile_state, setUp, pending_state):
mock_get_fargate_profile_state.return_value = pending_state

assert not self.sensor.poke({})
assert self.sensor.poke({}) is False
mock_get_fargate_profile_state.assert_called_once_with(
clusterName=CLUSTER_NAME, fargateProfileName=FARGATE_PROFILE_NAME
)
Expand Down Expand Up @@ -153,7 +153,7 @@ def setUp(self):

@mock.patch.object(EksHook, "get_nodegroup_state", return_value=NodegroupStates.ACTIVE)
def test_poke_reached_target_state(self, mock_get_nodegroup_state, setUp):
assert self.sensor.poke({})
assert self.sensor.poke({}) is True
mock_get_nodegroup_state.assert_called_once_with(
clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME
)
Expand All @@ -163,7 +163,7 @@ def test_poke_reached_target_state(self, mock_get_nodegroup_state, setUp):
def test_poke_reached_pending_state(self, mock_get_nodegroup_state, setUp, pending_state):
mock_get_nodegroup_state.return_value = pending_state

assert not self.sensor.poke({})
assert self.sensor.poke({}) is False
mock_get_nodegroup_state.assert_called_once_with(
clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME
)
Expand Down
4 changes: 1 addition & 3 deletions tests/providers/amazon/aws/sensors/test_emr_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
# under the License.
from __future__ import annotations

import unittest

import pytest

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -60,7 +58,7 @@ def failure_message_from_response(response):
return None


class TestEmrBaseSensor(unittest.TestCase):
class TestEmrBaseSensor:
def test_poke_returns_true_when_state_is_in_target_states(self):
operator = EmrBaseSensorSubclass(
task_id="test_task",
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/amazon/aws/sensors/test_emr_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

import pytest
Expand All @@ -27,8 +26,8 @@
from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor


class TestEmrContainerSensor(unittest.TestCase):
def setUp(self):
class TestEmrContainerSensor:
def setup_method(self):
self.sensor = EmrContainerSensor(
task_id="test_emrcontainer_sensor",
virtual_cluster_id="vzwemreks",
Expand Down
10 changes: 5 additions & 5 deletions tests/providers/amazon/aws/sensors/test_emr_job_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import datetime
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -188,8 +188,8 @@
}


class TestEmrJobFlowSensor(unittest.TestCase):
def setUp(self):
class TestEmrJobFlowSensor:
def setup_method(self):
# Mock out the emr_client (moto has incorrect response)
self.mock_emr_client = MagicMock()

Expand All @@ -216,7 +216,7 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self
assert self.mock_emr_client.describe_cluster.call_count == 3

# make sure it was called with the job_flow_id
calls = [unittest.mock.call(ClusterId="j-8989898989")]
calls = [mock.call(ClusterId="j-8989898989")]
self.mock_emr_client.describe_cluster.assert_has_calls(calls)

def test_execute_calls_with_the_job_flow_id_until_it_reaches_failed_state_with_exception(self):
Expand Down Expand Up @@ -262,5 +262,5 @@ def test_different_target_states(self):
assert self.mock_emr_client.describe_cluster.call_count == 3

# make sure it was called with the job_flow_id
calls = [unittest.mock.call(ClusterId="j-8989898989")]
calls = [mock.call(ClusterId="j-8989898989")]
self.mock_emr_client.describe_cluster.assert_has_calls(calls)
10 changes: 5 additions & 5 deletions tests/providers/amazon/aws/sensors/test_emr_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
# under the License.
from __future__ import annotations

import unittest
from datetime import datetime
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -142,8 +142,8 @@
}


class TestEmrStepSensor(unittest.TestCase):
def setUp(self):
class TestEmrStepSensor:
def setup_method(self):
self.emr_client_mock = MagicMock()
self.sensor = EmrStepSensor(
task_id="test_task",
Expand All @@ -170,8 +170,8 @@ def test_step_completed(self):

assert self.emr_client_mock.describe_step.call_count == 2
calls = [
unittest.mock.call(ClusterId="j-8989898989", StepId="s-VK57YR1Z9Z5N"),
unittest.mock.call(ClusterId="j-8989898989", StepId="s-VK57YR1Z9Z5N"),
mock.call(ClusterId="j-8989898989", StepId="s-VK57YR1Z9Z5N"),
mock.call(ClusterId="j-8989898989", StepId="s-VK57YR1Z9Z5N"),
]
self.emr_client_mock.describe_step.assert_has_calls(calls)

Expand Down
Loading

0 comments on commit b726d8e

Please sign in to comment.