Skip to content

Commit

Permalink
feature: AWS - GlueJobOperator - job_poll_interval (apache#32147)
Browse files: browse the repository at this point in the history
  • Loading branch information
raphaelauv authored Jun 26, 2023
1 parent 51dbbaf commit cc87ae5
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 15 deletions.
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/hooks/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ class GlueJobHook(AwsBaseHook):
- :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

JOB_POLL_INTERVAL = 6 # polls job status after every JOB_POLL_INTERVAL seconds

class LogContinuationTokens:
"""Used to hold the continuation tokens when reading logs from both streams Glue Jobs write to."""

Expand All @@ -75,6 +73,7 @@ def __init__(
iam_role_name: str | None = None,
create_job_kwargs: dict | None = None,
update_config: bool = False,
job_poll_interval: int | float = 6,
*args,
**kwargs,
):
Expand All @@ -88,6 +87,7 @@ def __init__(
self.s3_glue_logs = "logs/glue-logs/"
self.create_job_kwargs = create_job_kwargs or {}
self.update_config = update_config
self.job_poll_interval = job_poll_interval

worker_type_exists = "WorkerType" in self.create_job_kwargs
num_workers_exists = "NumberOfWorkers" in self.create_job_kwargs
Expand Down Expand Up @@ -278,7 +278,7 @@ def job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> d
if ret:
return ret
else:
time.sleep(self.JOB_POLL_INTERVAL)
time.sleep(self.job_poll_interval)

async def async_job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]:
"""
Expand All @@ -297,7 +297,7 @@ async def async_job_completion(self, job_name: str, run_id: str, verbose: bool =
if ret:
return ret
else:
await asyncio.sleep(self.JOB_POLL_INTERVAL)
await asyncio.sleep(self.job_poll_interval)

def _handle_state(
self,
Expand Down
4 changes: 4 additions & 0 deletions airflow/providers/amazon/aws/operators/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
deferrable: bool = False,
verbose: bool = False,
update_config: bool = False,
job_poll_interval: int | float = 6,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -121,6 +122,7 @@ def __init__(
self.verbose = verbose
self.update_config = update_config
self.deferrable = deferrable
self.job_poll_interval = job_poll_interval

def execute(self, context: Context):
"""Execute AWS Glue Job from Airflow.
Expand Down Expand Up @@ -151,6 +153,7 @@ def execute(self, context: Context):
iam_role_name=self.iam_role_name,
create_job_kwargs=self.create_job_kwargs,
update_config=self.update_config,
job_poll_interval=self.job_poll_interval,
)
self.log.info(
"Initializing AWS Glue Job: %s. Wait for completion: %s",
Expand Down Expand Up @@ -181,6 +184,7 @@ def execute(self, context: Context):
run_id=glue_job_run["JobRunId"],
verbose=self.verbose,
aws_conn_id=self.aws_conn_id,
job_poll_interval=self.job_poll_interval,
),
method_name="execute_complete",
)
Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/triggers/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ def __init__(
run_id: str,
verbose: bool,
aws_conn_id: str,
job_poll_interval: int | float,
):
super().__init__()
self.job_name = job_name
self.run_id = run_id
self.verbose = verbose
self.aws_conn_id = aws_conn_id
self.job_poll_interval = job_poll_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
Expand All @@ -54,10 +57,11 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"run_id": self.run_id,
"verbose": str(self.verbose),
"aws_conn_id": self.aws_conn_id,
"job_poll_interval": self.job_poll_interval,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
hook = GlueJobHook(aws_conn_id=self.aws_conn_id)
hook = GlueJobHook(aws_conn_id=self.aws_conn_id, job_poll_interval=self.job_poll_interval)
await hook.async_job_completion(self.job_name, self.run_id, self.verbose)
yield TriggerEvent({"status": "success", "message": "Job done", "value": self.run_id})
12 changes: 4 additions & 8 deletions tests/providers/amazon/aws/hooks/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,7 @@ def test_print_job_logs_no_stream_yet(self, conn_mock: MagicMock, client_mock: M

@mock.patch.object(GlueJobHook, "get_job_state")
def test_job_completion_success(self, get_state_mock: MagicMock):
hook = GlueJobHook()
hook.JOB_POLL_INTERVAL = 0
hook = GlueJobHook(job_poll_interval=0)
get_state_mock.side_effect = [
"RUNNING",
"RUNNING",
Expand All @@ -368,8 +367,7 @@ def test_job_completion_success(self, get_state_mock: MagicMock):

@mock.patch.object(GlueJobHook, "get_job_state")
def test_job_completion_failure(self, get_state_mock: MagicMock):
hook = GlueJobHook()
hook.JOB_POLL_INTERVAL = 0
hook = GlueJobHook(job_poll_interval=0)
get_state_mock.side_effect = [
"RUNNING",
"RUNNING",
Expand All @@ -384,8 +382,7 @@ def test_job_completion_failure(self, get_state_mock: MagicMock):
@pytest.mark.asyncio
@mock.patch.object(GlueJobHook, "async_get_job_state")
async def test_async_job_completion_success(self, get_state_mock: MagicMock):
hook = GlueJobHook()
hook.JOB_POLL_INTERVAL = 0
hook = GlueJobHook(job_poll_interval=0)
get_state_mock.side_effect = [
"RUNNING",
"RUNNING",
Expand All @@ -400,8 +397,7 @@ async def test_async_job_completion_success(self, get_state_mock: MagicMock):
@pytest.mark.asyncio
@mock.patch.object(GlueJobHook, "async_get_job_state")
async def test_async_job_completion_failure(self, get_state_mock: MagicMock):
hook = GlueJobHook()
hook.JOB_POLL_INTERVAL = 0
hook = GlueJobHook(job_poll_interval=0)
get_state_mock.side_effect = [
"RUNNING",
"RUNNING",
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/triggers/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ class TestGlueJobTrigger:
@pytest.mark.asyncio
@mock.patch.object(GlueJobHook, "async_get_job_state")
async def test_wait_job(self, get_state_mock: mock.MagicMock):
GlueJobHook.JOB_POLL_INTERVAL = 0.1
trigger = GlueJobCompleteTrigger(
job_name="job_name",
run_id="JobRunId",
verbose=False,
aws_conn_id="aws_conn_id",
job_poll_interval=0.1,
)
get_state_mock.side_effect = [
"RUNNING",
Expand All @@ -52,12 +52,12 @@ async def test_wait_job(self, get_state_mock: mock.MagicMock):
@pytest.mark.asyncio
@mock.patch.object(GlueJobHook, "async_get_job_state")
async def test_wait_job_failed(self, get_state_mock: mock.MagicMock):
GlueJobHook.JOB_POLL_INTERVAL = 0.1
trigger = GlueJobCompleteTrigger(
job_name="job_name",
run_id="JobRunId",
verbose=False,
aws_conn_id="aws_conn_id",
job_poll_interval=0.1,
)
get_state_mock.side_effect = [
"RUNNING",
Expand Down

0 comments on commit cc87ae5

Please sign in to comment.