Skip to content

Commit

Permalink
bugfix: EMRHook Loop through paginated response to check for cluster …
Browse files Browse the repository at this point in the history
…id (apache#29732)

* bugfix: EMRHook Loop through paginated response to check for cluster id
---------

Co-authored-by: Nakkul Sreenivas <[email protected]>
Co-authored-by: Nakkul Sreenivas <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
Co-authored-by: Niko Oliveira <[email protected]>
  • Loading branch information
5 people authored May 1, 2023
1 parent 607068f commit 9662fd8
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
12 changes: 8 additions & 4 deletions airflow/providers/amazon/aws/hooks/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,15 @@ def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: list[str
:param cluster_states: State(s) of cluster to find
:return: id of the EMR cluster
"""
response = self.get_conn().list_clusters(ClusterStates=cluster_states)

matching_clusters = list(
filter(lambda cluster: cluster["Name"] == emr_cluster_name, response["Clusters"])
response_iterator = (
self.get_conn().get_paginator("list_clusters").paginate(ClusterStates=cluster_states)
)
matching_clusters = [
cluster
for page in response_iterator
for cluster in page["Clusters"]
if cluster["Name"] == emr_cluster_name
]

if len(matching_clusters) == 1:
cluster_id = matching_clusters[0]["Id"]
Expand Down
47 changes: 47 additions & 0 deletions tests/providers/amazon/aws/hooks/test_emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pytest
from moto import mock_emr

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrHook


Expand Down Expand Up @@ -198,6 +199,52 @@ def test_get_cluster_id_by_name(self):

assert no_match is None

@mock_emr
def test_get_cluster_id_by_name_duplicate(self):
    """
    Looking up a cluster name shared by two clusters must raise
    an AirflowException.
    """
    emr_hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default")

    # Register two job flows under the very same name so the lookup is ambiguous.
    for _ in range(2):
        emr_hook.create_job_flow(
            {"Name": "test_cluster", "Instances": {"KeepJobFlowAliveWhenNoSteps": True}}
        )

    with pytest.raises(AirflowException):
        emr_hook.get_cluster_id_by_name("test_cluster", ["RUNNING", "WAITING", "BOOTSTRAPPING"])

@mock_emr
def test_get_cluster_id_by_name_pagination(self):
    """
    Resolving a cluster id by name must still work when the cluster
    listing is spread over more than one page of results.
    """
    states = ["RUNNING", "WAITING", "BOOTSTRAPPING"]
    emr_hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default")

    # Create enough uniquely named clusters for the listing to be paginated.
    for index in range(51):
        job_flow_overrides = {
            "Name": f"test_cluster_{index}",
            "Instances": {"KeepJobFlowAliveWhenNoSteps": True},
        }
        emr_hook.create_job_flow(job_flow_overrides)

    # Use the raw boto API to pick a cluster that only appears on the second page.
    emr_client = boto3.client("emr", region_name="us-east-1")
    first_page = emr_client.list_clusters(ClusterStates=states)
    next_marker = first_page["Marker"]
    second_page = emr_client.list_clusters(ClusterStates=states, Marker=next_marker)
    expected_cluster = second_page["Clusters"][0]

    # Ask the hook to resolve that cluster's id from its name alone.
    resolved_id = emr_hook.get_cluster_id_by_name(expected_cluster["Name"], states)

    # The hook and the boto API must agree on the id.
    assert resolved_id == expected_cluster["Id"]

@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn")
def test_add_job_flow_steps_execution_role_arn(self, mock_conn):
"""
Expand Down

0 comments on commit 9662fd8

Please sign in to comment.