AIP-72: Handle External update TI state in Supervisor (apache#44406)
- Updated logic to handle externally updated TI state in Supervisor. These states could have been changed externally via the UI, CLI, API, etc. (see the sketch below)
- Replaced `FASTEST_HEARTBEAT_INTERVAL` and `SLOWEST_HEARTBEAT_INTERVAL` with `MIN_HEARTBEAT_INTERVAL` and `HEARTBEAT_THRESHOLD` for better clarity
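
For context, here is a minimal sketch (not the committed code) of how the supervisor reacts when the server reports that the task should no longer run; the real logic lives in `_send_heartbeat_if_needed` in the diff below, and `react_to_heartbeat_error`, `kill`, and `log` are illustrative names:

```python
from http import HTTPStatus


def react_to_heartbeat_error(status_code: int, kill, log) -> None:
    """Decide what to do when the heartbeat call is rejected by the API server."""
    if status_code in (HTTPStatus.NOT_FOUND, HTTPStatus.CONFLICT):
        # 404/409 signal that the TI state was changed externally (UI, CLI, API),
        # so the supervised process should be stopped rather than kept running.
        log("Server indicated the task shouldn't be running anymore")
        kill()
    else:
        # Any other error is treated as transient and retried on the next attempt.
        log("Heartbeat failed; will retry")
```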

This is part of my efforts to port LocalTaskJob tests to Supervisor: apache#44356.

This ports over `TestLocalTaskJob.test_mark_{success,failure}_no_kill`.

This PR also allows retrying heartbeats:

- Added `_last_successful_heartbeat` and `_last_heartbeat_attempt` to track successful heartbeats and retry attempts separately.
- `MIN_HEARTBEAT_INTERVAL` is now respected between heartbeat attempts, even after failures.
- The number of retries is configurable via `MAX_FAILED_HEARTBEATS` (see the sketch below)
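
A minimal sketch of the retry pacing described above, assuming the default values from the diff; `HeartbeatPacer`, `send_heartbeat`, and `kill` are illustrative names, not part of the commit:

```python
import time

MIN_HEARTBEAT_INTERVAL = 5  # seconds between heartbeat *attempts*
MAX_FAILED_HEARTBEATS = 3   # consecutive failures tolerated before giving up


class HeartbeatPacer:
    """Pace heartbeat attempts and count consecutive failures."""

    def __init__(self) -> None:
        self._last_attempt = 0.0
        self._failed = 0

    def maybe_heartbeat(self, send_heartbeat, kill) -> None:
        # Respect the minimum interval between attempts, even after a failure.
        if time.monotonic() - self._last_attempt < MIN_HEARTBEAT_INTERVAL:
            return
        self._last_attempt = time.monotonic()
        try:
            send_heartbeat()
            self._failed = 0  # reset the counter on success
        except Exception:
            self._failed += 1
            if self._failed >= MAX_FAILED_HEARTBEATS:
                kill()  # give up after too many consecutive failures
```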
kaxil authored Nov 28, 2024
1 parent 1f18458 commit 6e3a25e
Showing 3 changed files with 242 additions and 42 deletions.
123 changes: 93 additions & 30 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -31,6 +31,7 @@
from collections.abc import Generator
from contextlib import suppress
from datetime import datetime, timezone
from http import HTTPStatus
from socket import socket, socketpair
from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, TextIO, cast, overload
from uuid import UUID
@@ -42,7 +43,7 @@
import structlog
from pydantic import TypeAdapter

from airflow.sdk.api.client import Client
from airflow.sdk.api.client import Client, ServerResponseError
from airflow.sdk.api.datamodels._generated import IntermediateTIState, TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
DeferTask,
@@ -54,6 +55,8 @@
)

if TYPE_CHECKING:
from selectors import SelectorKey

from structlog.typing import FilteringBoundLogger, WrappedLogger


@@ -62,9 +65,12 @@
log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor")

# TODO: Pull this from config
SLOWEST_HEARTBEAT_INTERVAL: int = 30
# (previously `[scheduler] local_task_job_heartbeat_sec` with the following as fallback if it is 0:
# `[scheduler] scheduler_zombie_task_threshold`)
HEARTBEAT_THRESHOLD: int = 30
# Don't heartbeat more often than this
FASTEST_HEARTBEAT_INTERVAL: int = 5
MIN_HEARTBEAT_INTERVAL: int = 5
MAX_FAILED_HEARTBEATS: int = 3


@overload
@@ -265,7 +271,13 @@ class WatchedSubprocess:
_terminal_state: str | None = None
_final_state: str | None = None

_last_heartbeat: float = 0
_last_successful_heartbeat: float = attrs.field(default=0, init=False)
_last_heartbeat_attempt: float = attrs.field(default=0, init=False)

# After the failure of a heartbeat, we'll increment this counter. If it reaches `MAX_FAILED_HEARTBEATS`, we
# will kill the process. This is to handle temporary network issues etc., ensuring that the process
# does not hang around forever.
failed_heartbeats: int = attrs.field(default=0, init=False)

selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector)

@@ -320,7 +332,7 @@ def start(
# reason)
try:
client.task_instances.start(ti.id, pid, datetime.now(tz=timezone.utc))
proc._last_heartbeat = time.monotonic()
proc._last_successful_heartbeat = time.monotonic()
except Exception:
# On any error kill that subprocess!
proc.kill(signal.SIGKILL)
@@ -423,38 +435,55 @@ def _monitor_subprocess(self):
This function:
- Polls the subprocess for output
- Sends heartbeats to the client to keep the task alive
- Checks if the subprocess has exited
- Waits for activity on file objects (e.g., subprocess stdout, stderr, logs, requests) using the selector.
- Processes events triggered on the monitored file objects, such as data availability or EOF.
- Sends heartbeats to ensure the process is alive and checks if the subprocess has exited.
"""
# Until we have a selector for the process, don't poll for more than 10s, just in case it exits but
# doesn't produce any output
max_poll_interval = 10

while self._exit_code is None or len(self.selector.get_map()):
last_heartbeat_ago = time.monotonic() - self._last_heartbeat
last_heartbeat_ago = time.monotonic() - self._last_successful_heartbeat
# Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible
# so we notice the subprocess finishing as quick as we can.
max_wait_time = max(
0, # Make sure this value is never negative,
min(
# Ensure we heartbeat _at most_ 75% of the way through the zombie threshold time
SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75,
max_poll_interval,
HEARTBEAT_THRESHOLD - last_heartbeat_ago * 0.75,
MIN_HEARTBEAT_INTERVAL,
),
)
# Block until events are ready or the timeout is reached
# This listens for activity (e.g., subprocess output) on registered file objects
events = self.selector.select(timeout=max_wait_time)
for key, _ in events:
socket_handler = key.data
need_more = socket_handler(key.fileobj)

if not need_more:
self.selector.unregister(key.fileobj)
key.fileobj.close() # type: ignore[union-attr]
self._process_file_object_events(events)

self._check_subprocess_exit()
self._send_heartbeat_if_needed()

def _process_file_object_events(self, events: list[tuple[SelectorKey, int]]):
"""
Process selector events by invoking handlers for each file object.
For each file object event, this method retrieves the associated handler and processes
the event. If the handler indicates that the file object no longer needs
monitoring (e.g., EOF or closed), the file object is unregistered and closed.
"""
for key, _ in events:
# Retrieve the handler responsible for processing this file object (e.g., stdout, stderr)
socket_handler = key.data

# Example of handler behavior:
# If the subprocess writes "Hello, World!" to stdout:
# - `socket_handler` reads and processes the message.
# - If EOF is reached, the handler returns False to signal no more reads are expected.
need_more = socket_handler(key.fileobj)

# If the handler signals that the file object is no longer needed (EOF, closed, etc.)
# unregister it from the selector to stop monitoring; `wait()` blocks until all selectors
# are removed.
if not need_more:
self.selector.unregister(key.fileobj)
key.fileobj.close() # type: ignore[union-attr]

def _check_subprocess_exit(self):
"""Check if the subprocess has exited."""
if self._exit_code is None:
@@ -466,14 +495,48 @@ def _check_subprocess_exit(self):

def _send_heartbeat_if_needed(self):
"""Send a heartbeat to the client if heartbeat interval has passed."""
if time.monotonic() - self._last_heartbeat >= FASTEST_HEARTBEAT_INTERVAL:
try:
self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
self._last_heartbeat = time.monotonic()
except Exception:
log.warning("Failed to send heartbeat", exc_info=True)
# TODO: If we couldn't heartbeat for X times the interval, kill ourselves
pass
# Respect the minimum interval between heartbeat attempts
if (time.monotonic() - self._last_heartbeat_attempt) < MIN_HEARTBEAT_INTERVAL:
return

self._last_heartbeat_attempt = time.monotonic()
try:
self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
# Update the last heartbeat time on success
self._last_successful_heartbeat = time.monotonic()

# Reset the counter on success
self.failed_heartbeats = 0
except ServerResponseError as e:
if e.response.status_code in {HTTPStatus.NOT_FOUND, HTTPStatus.CONFLICT}:
log.error(
"Server indicated the task shouldn't be running anymore",
detail=e.detail,
status_code=e.response.status_code,
)
self.kill(signal.SIGTERM)
else:
# If we get any other error, we'll just log it and try again next time
self._handle_heartbeat_failures()
except Exception:
self._handle_heartbeat_failures()

def _handle_heartbeat_failures(self):
"""Increment the failed heartbeats counter and kill the process if too many failures."""
self.failed_heartbeats += 1
log.warning(
"Failed to send heartbeat. Will be retried",
failed_heartbeats=self.failed_heartbeats,
ti_id=self.ti_id,
max_retries=MAX_FAILED_HEARTBEATS,
exc_info=True,
)
# If we've failed to heartbeat too many times, kill the process
if self.failed_heartbeats >= MAX_FAILED_HEARTBEATS:
log.error(
"Too many failed heartbeats; terminating process", failed_heartbeats=self.failed_heartbeats
)
self.kill(signal.SIGTERM)

@property
def final_state(self):
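
As a reference for the selector-driven loop described in the `_monitor_subprocess` and `_process_file_object_events` docstrings, here is a minimal, self-contained sketch of the same pattern outside the supervisor; the subprocess command and the `on_stdout` handler are illustrative, not part of the commit:

```python
import os
import selectors
import subprocess

# Register a subprocess pipe with a selector, block in select() until it has
# data or hits EOF, and unregister it once the handler says it is done.
proc = subprocess.Popen(["echo", "hello"], stdout=subprocess.PIPE)
sel = selectors.DefaultSelector()


def on_stdout(fileobj) -> bool:
    """Return True while more data is expected, False at EOF."""
    data = os.read(fileobj.fileno(), 4096)
    if not data:
        return False
    print("captured:", data.decode().rstrip())
    return True


sel.register(proc.stdout, selectors.EVENT_READ, data=on_stdout)

while sel.get_map():
    for key, _ in sel.select(timeout=10):
        if not key.data(key.fileobj):
            sel.unregister(key.fileobj)
            key.fileobj.close()

proc.wait()
```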
9 changes: 8 additions & 1 deletion task_sdk/tests/conftest.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, NoReturn
@@ -72,7 +73,7 @@ def test_dags_dir():


@pytest.fixture
def captured_logs():
def captured_logs(request):
import structlog

from airflow.sdk.log import configure_logging, reset_logging
@@ -81,6 +82,12 @@ def captured_logs():
reset_logging()
configure_logging(enable_pretty_log=False)

# Get log level from test parameter, defaulting to INFO if not provided
log_level = getattr(request, "param", logging.INFO)

# We want to capture all logs, but we don't want to see them in the test output
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(log_level))

# But we need to remove the last processor (the one that turns JSON into text, as we want the
# event dict for tests)
cur_processors = structlog.get_config()["processors"]
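
With the `request.param` handling added to `captured_logs` above, a test should be able to raise the captured log level via pytest's indirect parametrization. A hypothetical usage sketch (the test name and body are illustrative, not part of this commit):

```python
import logging

import pytest


@pytest.mark.parametrize("captured_logs", [logging.WARNING], indirect=True)
def test_warning_level_capture(captured_logs):
    # `captured_logs` receives logging.WARNING as `request.param`, so the
    # filtering bound logger drops INFO and DEBUG events for this test.
    ...
```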
