Skip to content

Commit

Permalink
Modify KPO to log container log periodically (apache#37279)
Browse files Browse the repository at this point in the history
Currently, the KubernetesPodOperator (KPO)
prints container logs only at the end of task execution.
While this is sufficient for short-running tasks,
it becomes less user-friendly when the container runs for an extended period.
This PR enhances the KPO by modifying the trigger and operator
to fetch container logs periodically,
making it possible to monitor the task's progress in the Airflow task UI.

A new parameter has been introduced to the operator:

logging_interval: This parameter specifies the maximum time,
in seconds, that the task should remain deferred before resuming to fetch the latest logs.
  • Loading branch information
pankajastro authored Feb 12, 2024
1 parent ed346c7 commit 053485b
Show file tree
Hide file tree
Showing 8 changed files with 378 additions and 280 deletions.
92 changes: 87 additions & 5 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@
from urllib3.exceptions import HTTPError

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.exceptions import (
AirflowException,
AirflowProviderDeprecationWarning,
AirflowSkipException,
TaskDeferred,
)
from airflow.models import BaseOperator
from airflow.providers.cncf.kubernetes import pod_generator
from airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters import (
Expand Down Expand Up @@ -63,7 +68,9 @@
EMPTY_XCOM_RESULT,
OnFinishAction,
PodLaunchFailedException,
PodLaunchTimeoutException,
PodManager,
PodNotFoundException,
PodOperatorHookProtocol,
PodPhase,
container_is_succeeded,
Expand All @@ -77,6 +84,7 @@

if TYPE_CHECKING:
import jinja2
from pendulum import DateTime
from typing_extensions import Literal

from airflow.providers.cncf.kubernetes.secret import Secret
Expand Down Expand Up @@ -203,6 +211,9 @@ class KubernetesPodOperator(BaseOperator):
of KubernetesPodOperator.
:param progress_callback: Callback function for receiving k8s container logs.
`progress_callback` is deprecated, please use :param `callbacks` instead.
:param logging_interval: max time in seconds that task should be in deferred state before
resuming to fetch the latest logs. If ``None``, then the task will remain in deferred state until pod
is done, and no logs will be visible until that time.
"""

# !!! Changes in KubernetesPodOperator's arguments should be also reflected in !!!
Expand Down Expand Up @@ -297,6 +308,7 @@ def __init__(
active_deadline_seconds: int | None = None,
callbacks: type[KubernetesPodOperatorCallback] | None = None,
progress_callback: Callable[[str], None] | None = None,
logging_interval: int | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -387,6 +399,7 @@ def __init__(
self.is_delete_operator_pod = self.on_finish_action == OnFinishAction.DELETE_POD
self.termination_message_policy = termination_message_policy
self.active_deadline_seconds = active_deadline_seconds
self.logging_interval = logging_interval

self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict
self._progress_callback = progress_callback
Expand Down Expand Up @@ -641,13 +654,13 @@ def execute_async(self, context: Context):

self.invoke_defer_method()

def invoke_defer_method(self):
def invoke_defer_method(self, last_log_time: DateTime | None = None):
"""Redefine triggers which are being used in child classes."""
trigger_start_time = utcnow()
self.defer(
trigger=KubernetesPodTrigger(
pod_name=self.pod.metadata.name,
pod_namespace=self.pod.metadata.namespace,
pod_name=self.pod.metadata.name, # type: ignore[union-attr]
pod_namespace=self.pod.metadata.namespace, # type: ignore[union-attr]
trigger_start_time=trigger_start_time,
kubernetes_conn_id=self.kubernetes_conn_id,
cluster_context=self.cluster_context,
Expand All @@ -659,10 +672,79 @@ def invoke_defer_method(self):
startup_check_interval=self.startup_check_interval_seconds,
base_container_name=self.base_container_name,
on_finish_action=self.on_finish_action.value,
last_log_time=last_log_time,
logging_interval=self.logging_interval,
),
method_name="execute_complete",
method_name="trigger_reentry",
)

@staticmethod
def raise_for_trigger_status(event: dict[str, Any]) -> None:
    """
    Re-raise a failure reported by the trigger.

    Events whose ``status`` is anything other than ``"error"`` are accepted silently.

    :param event: trigger event payload; ``status`` is always read, ``error_type`` and
        ``description`` are read only for error events.
    :raises PodLaunchTimeoutException: if the event reports that error type.
    :raises AirflowException: for every other error event.
    """
    if event["status"] != "error":
        return

    failure_kind = event["error_type"]
    details = event["description"]
    if failure_kind == "PodLaunchTimeoutException":
        raise PodLaunchTimeoutException(details)
    raise AirflowException(details)

def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
    """
    Point of re-entry from trigger.

    If ``logging_interval`` is None, then at this point the pod should be done and we'll just fetch
    the logs and exit.

    If ``logging_interval`` is not None, it could be that the pod is still running and we'll just
    grab the latest logs and defer back to the trigger again.

    :param context: task context, used to rebuild the pod request and to look the pod up again.
    :param event: payload of the trigger event; ``last_log_time`` is read from it when
        ``get_logs`` is enabled.
    :return: the extracted XCom value when ``do_xcom_push`` is set, otherwise ``None``.
    :raises PodNotFoundException: if the pod cannot be found after resuming from deferral.
    """
    remote_pod = None
    result = None
    try:
        self.pod_request_obj = self.build_pod_request_obj(context)
        self.pod = self.find_pod(
            namespace=self.namespace or self.pod_request_obj.metadata.namespace,
            context=context,
        )

        # we try to find pod before possibly raising so that on_kill will have `pod` attr
        self.raise_for_trigger_status(event)

        if not self.pod:
            raise PodNotFoundException("Could not find pod after resuming from deferral")

        if self.get_logs:
            last_log_time = event.get("last_log_time") if event else None
            if last_log_time:
                self.log.info("Resuming logs read from time %r", last_log_time)
            pod_log_status = self.pod_manager.fetch_container_logs(
                pod=self.pod,
                # Read logs from the same container the trigger was told to watch
                # (``base_container_name``), not from the class-level default name.
                container_name=self.base_container_name,
                # Without a logging interval there is no further re-entry, so follow
                # the log stream instead of taking a snapshot.
                follow=self.logging_interval is None,
                since_time=last_log_time,
            )
            if pod_log_status.running:
                self.log.info("Container still running; deferring again.")
                self.invoke_defer_method(pod_log_status.last_log_time)

        if self.do_xcom_push:
            result = self.extract_xcom(pod=self.pod)
        remote_pod = self.pod_manager.await_pod_completion(self.pod)
    except TaskDeferred:
        # Deferring again is not a failure: skip cleanup so the pod keeps running.
        raise
    except Exception:
        self.cleanup(
            pod=self.pod or self.pod_request_obj,
            remote_pod=remote_pod,
        )
        raise
    self.cleanup(
        pod=self.pod or self.pod_request_obj,
        remote_pod=remote_pod,
    )
    return result

def execute_complete(self, context: Context, event: dict, **kwargs):
self.log.debug("Triggered with event: %s", event)
pod = None
Expand Down
157 changes: 72 additions & 85 deletions airflow/providers/cncf/kubernetes/triggers/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,24 @@
import datetime
import traceback
import warnings
from asyncio import CancelledError
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, AsyncIterator

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodPhase
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
OnFinishAction,
PodLaunchTimeoutException,
PodPhase,
container_is_running,
)
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils import timezone

if TYPE_CHECKING:
from kubernetes_asyncio.client.models import V1Pod
from pendulum import DateTime


class ContainerState(str, Enum):
Expand Down Expand Up @@ -70,6 +76,9 @@ class KubernetesPodTrigger(BaseTrigger):
state, or the execution is interrupted. If True (default), delete the
pod; if False, leave the pod.
Deprecated - use `on_finish_action` instead.
:param logging_interval: number of seconds to wait before kicking it back to
the operator to print latest logs. If ``None`` will wait until container done.
:param last_log_time: where to resume logs from
"""

def __init__(
Expand All @@ -88,6 +97,8 @@ def __init__(
startup_check_interval: int = 1,
on_finish_action: str = "delete_pod",
should_delete_pod: bool | None = None,
last_log_time: DateTime | None = None,
logging_interval: int | None = None,
):
super().__init__()
self.pod_name = pod_name
Expand All @@ -102,6 +113,8 @@ def __init__(
self.get_logs = get_logs
self.startup_timeout = startup_timeout
self.startup_check_interval = startup_check_interval
self.last_log_time = last_log_time
self.logging_interval = logging_interval

if should_delete_pod is not None:
warnings.warn(
Expand Down Expand Up @@ -137,104 +150,78 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"trigger_start_time": self.trigger_start_time,
"should_delete_pod": self.should_delete_pod,
"on_finish_action": self.on_finish_action.value,
"last_log_time": self.last_log_time,
"logging_interval": self.logging_interval,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current pod status and yield a TriggerEvent."""
self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
try:
while True:
pod = await self.hook.get_pod(
name=self.pod_name,
namespace=self.pod_namespace,
state = await self._wait_for_pod_start()
if state in PodPhase.terminal_states:
event = TriggerEvent(
{"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
)

pod_status = pod.status.phase
self.log.debug("Pod %s status: %s", self.pod_name, pod_status)

container_state = self.define_container_state(pod)
self.log.debug("Container %s status: %s", self.base_container_name, container_state)

if container_state == ContainerState.TERMINATED:
yield TriggerEvent(
{
"name": self.pod_name,
"namespace": self.pod_namespace,
"status": "success",
"message": "All containers inside pod have started successfully.",
}
)
return
elif self.should_wait(pod_phase=pod_status, container_state=container_state):
self.log.info("Container is not completed and still working.")

if pod_status == PodPhase.PENDING and container_state != ContainerState.RUNNING:
delta = datetime.datetime.now(tz=datetime.timezone.utc) - self.trigger_start_time
if delta.total_seconds() >= self.startup_timeout:
message = (
f"Pod took longer than {self.startup_timeout} seconds to start. "
"Check the pod events in kubernetes to determine why."
)
yield TriggerEvent(
{
"name": self.pod_name,
"namespace": self.pod_namespace,
"status": "timeout",
"message": message,
}
)
return
else:
self.log.info("Sleeping for %s seconds.", self.startup_check_interval)
await asyncio.sleep(self.startup_check_interval)
else:
self.log.info("Sleeping for %s seconds.", self.poll_interval)
await asyncio.sleep(self.poll_interval)
else:
yield TriggerEvent(
{
"name": self.pod_name,
"namespace": self.pod_namespace,
"status": "failed",
"message": pod.status.message,
}
)
return
except CancelledError:
# That means that task was marked as failed
if self.get_logs:
self.log.info("Outputting container logs...")
await self.hook.read_logs(
name=self.pod_name,
namespace=self.pod_namespace,
)
if self.on_finish_action == OnFinishAction.DELETE_POD:
self.log.info("Deleting pod...")
await self.hook.delete_pod(
name=self.pod_name,
namespace=self.pod_namespace,
)
yield TriggerEvent(
{
"name": self.pod_name,
"namespace": self.pod_namespace,
"status": "cancelled",
"message": "Pod execution was cancelled",
}
)
else:
event = await self._wait_for_container_completion()
yield event
except Exception as e:
self.log.exception("Exception occurred while checking pod phase:")
description = self._format_exception_description(e)
yield TriggerEvent(
{
"name": self.pod_name,
"namespace": self.pod_namespace,
"status": "error",
"message": str(e),
"stack_trace": traceback.format_exc(),
"error_type": e.__class__.__name__,
"description": description,
}
)

def _format_exception_description(self, exc: Exception) -> Any:
    """
    Build the ``description`` field of an error event for ``exc``.

    :param exc: the exception caught while the trigger was running.
    :return: for ``PodLaunchTimeoutException`` the first exception argument as-is (or
        ``str(exc)`` when it has none); otherwise a multi-line string naming the trigger
        class and exception class, followed by the message (if any) and the traceback.
    """
    if isinstance(exc, PodLaunchTimeoutException):
        # Guard against an argument-less exception so this never raises IndexError.
        return exc.args[0] if exc.args else str(exc)

    description = f"Trigger {self.__class__.__name__} failed with exception {exc.__class__.__name__}."
    message = exc.args[0] if exc.args else ""
    if message:
        description += f"\ntrigger exception message: {message}"
    # Format the traceback of ``exc`` itself rather than of whatever exception happens
    # to be "current", so the result does not depend on where this method is called from.
    curr_traceback = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    description += f"\ntrigger traceback:\n{curr_traceback}"
    return description

async def _wait_for_pod_start(self) -> Any:
    """
    Loop until the pod phase leaves ``Pending``.

    The pod is polled every ``startup_check_interval`` seconds and is always inspected
    at least once, even when ``startup_timeout`` is zero.

    :return: the pod phase observed once it is no longer ``Pending``.
    :raises PodLaunchTimeoutException: if the pod is still ``Pending`` once
        ``startup_timeout`` seconds have passed since this method was called.
    """
    timeout_end = timezone.utcnow() + datetime.timedelta(seconds=self.startup_timeout)
    while True:
        pod = await self.hook.get_pod(self.pod_name, self.pod_namespace)
        phase = pod.status.phase
        if phase != "Pending":
            return phase
        # Check the deadline only after polling, so an already-started pod is never
        # reported as a launch timeout.
        if timezone.utcnow() >= timeout_end:
            raise PodLaunchTimeoutException("Pod did not leave 'Pending' phase within specified timeout")
        self.log.info("Still waiting for pod to start. The pod state is %s", phase)
        # Startup polling has its own interval; ``poll_interval`` is for the running phase.
        await asyncio.sleep(self.startup_check_interval)

async def _wait_for_container_completion(self) -> TriggerEvent:
    """
    Wait until the base container stops running or it is time to fetch more logs.

    The pod is polled every ``poll_interval`` seconds. When ``logging_interval`` is set
    and that many seconds have passed since this method was called, a ``running`` event
    is returned so the task can resume and fetch more logs.

    :return: a ``done`` event once the container is no longer running, or a ``running``
        event carrying ``last_log_time`` when the logging interval has elapsed.
    """
    resume_at = None
    if self.logging_interval is not None:
        resume_at = timezone.utcnow() + datetime.timedelta(seconds=self.logging_interval)

    while True:
        current_pod = await self.hook.get_pod(self.pod_name, self.pod_namespace)
        still_running = container_is_running(pod=current_pod, container_name=self.base_container_name)
        if not still_running:
            done_payload = {"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
            return TriggerEvent(done_payload)
        if resume_at is not None and timezone.utcnow() > resume_at:
            running_payload = {"status": "running", "last_log_time": self.last_log_time}
            return TriggerEvent(running_payload)
        await asyncio.sleep(self.poll_interval)

def _get_async_hook(self) -> AsyncKubernetesHook:
# TODO: Remove this method when the min version of kubernetes provider is 7.12.0 in Google provider.
return AsyncKubernetesHook(
Expand Down
8 changes: 8 additions & 0 deletions airflow/providers/cncf/kubernetes/utils/pod_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ def get_container_termination_message(pod: V1Pod, container_name: str):
return container_status.state.terminated.message if container_status else None


class PodLaunchTimeoutException(AirflowException):
    """Raised when a pod does not leave the ``Pending`` phase within the specified timeout."""


class PodNotFoundException(AirflowException):
    """Raised when the expected pod does not exist in kube-api."""


class PodLogsConsumer:
"""
Responsible for pulling pod logs from a stream with checking a container status before reading data.
Expand Down
Loading

0 comments on commit 053485b

Please sign in to comment.