Skip to content

Commit

Permalink
Fix poll_interval in GKEJobTrigger (apache#41712)
Browse files Browse the repository at this point in the history
  • Loading branch information
gopidesupavan authored Sep 1, 2024
1 parent 04217f1 commit 2823acd
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
24 changes: 21 additions & 3 deletions airflow/providers/google/cloud/triggers/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence

from google.cloud.container_v1.types import Operation
from packaging.version import parse as parse_version

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodManager
from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
Expand All @@ -33,6 +34,7 @@
GKEKubernetesAsyncHook,
GKEKubernetesHook,
)
from airflow.providers_manager import ProvidersManager
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
Expand Down Expand Up @@ -305,19 +307,35 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
if self.get_logs or self.do_xcom_push:
pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace)
if self.do_xcom_push:
kubernetes_provider = ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"]
kubernetes_provider_name = kubernetes_provider.data["package-name"]
kubernetes_provider_version = kubernetes_provider.version
min_version = "8.4.1"
if parse_version(kubernetes_provider_version) < parse_version(min_version):
raise AirflowException(
"You are trying to use do_xcom_push in `GKEStartJobOperator` with the provider "
f"package {kubernetes_provider_name}=={kubernetes_provider_version} which doesn't "
f"support this feature. Please upgrade it to version higher than or equal to {min_version}."
)
await self.hook.wait_until_container_complete(
name=self.pod_name, namespace=self.pod_namespace, container_name=self.base_container_name
name=self.pod_name,
namespace=self.pod_namespace,
container_name=self.base_container_name,
poll_interval=self.poll_interval,
)
self.log.info("Checking if xcom sidecar container is started.")
await self.hook.wait_until_container_started(
name=self.pod_name,
namespace=self.pod_namespace,
container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
poll_interval=self.poll_interval,
)
self.log.info("Extracting result from xcom sidecar container.")
loop = asyncio.get_running_loop()
xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace)
job: V1Job = await self.hook.wait_until_job_complete(
name=self.job_name, namespace=self.job_namespace, poll_interval=self.poll_interval
)
job_dict = job.to_dict()
error_message = self.hook.is_job_failed(job=job)
status = "error" if error_message else "success"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,9 @@ async def test_run_success(self, mock_hook, job_trigger):

event_actual = await job_trigger.run().asend(None)

mock_hook.wait_until_job_complete.assert_called_once_with(name=JOB_NAME, namespace=NAMESPACE)
mock_hook.wait_until_job_complete.assert_called_once_with(
name=JOB_NAME, namespace=NAMESPACE, poll_interval=POLL_INTERVAL
)
mock_job.to_dict.assert_called_once()
mock_is_job_failed.assert_called_once_with(job=mock_job)
assert event_actual == TriggerEvent(
Expand Down Expand Up @@ -544,7 +546,9 @@ async def test_run_fail(self, mock_hook, job_trigger):

event_actual = await job_trigger.run().asend(None)

mock_hook.wait_until_job_complete.assert_called_once_with(name=JOB_NAME, namespace=NAMESPACE)
mock_hook.wait_until_job_complete.assert_called_once_with(
name=JOB_NAME, namespace=NAMESPACE, poll_interval=POLL_INTERVAL
)
mock_job.to_dict.assert_called_once()
mock_is_job_failed.assert_called_once_with(job=mock_job)
assert event_actual == TriggerEvent(
Expand Down

0 comments on commit 2823acd

Please sign in to comment.