Skip to content

Commit

Permalink
AWS logs. Exit fast when 3 consecutive responses are returned from AWS Cloudwatch logs (apache#30756)
Browse files Browse the repository at this point in the history

* AWS logs. Exit fast when 3 consecutive responses are returned from AWS Cloudwatch logs

---------

Co-authored-by: Niko Oliveira <[email protected]>
  • Loading branch information
vincbeck and o-nikolas authored Apr 21, 2023
1 parent ac6ef75 commit 7e01c09
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 36 deletions.
26 changes: 22 additions & 4 deletions airflow/providers/amazon/aws/hooks/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

# Guidance received from the AWS team regarding the correct way to check for the end of a stream is that the
# value of the nextForwardToken is the same in subsequent calls.
# The issue with this approach is, it can take a huge amount of time (e.g. 20 seconds) to retrieve logs using
# this approach. As an intermediate solution, we decided to stop fetching logs if 3 consecutive responses
# are empty.
# See PR https://github.com/apache/airflow/pull/20814
NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD = 3


class AwsLogsHook(AwsBaseHook):
"""
Expand Down Expand Up @@ -69,14 +77,15 @@ def get_log_events(
| 'message' (str): The log event data.
| 'ingestionTime' (int): The time in milliseconds the event was ingested.
"""
num_consecutive_empty_response = 0
next_token = None
while True:
if next_token is not None:
token_arg: dict[str, str] | None = {"nextToken": next_token}
token_arg: dict[str, str] = {"nextToken": next_token}
else:
token_arg = {}

response = self.get_conn().get_log_events(
response = self.conn.get_log_events(
logGroupName=log_group,
logStreamName=log_stream_name,
startTime=start_time,
Expand All @@ -96,7 +105,16 @@ def get_log_events(

yield from events

if next_token != response["nextForwardToken"]:
next_token = response["nextForwardToken"]
if not event_count:
num_consecutive_empty_response += 1
if num_consecutive_empty_response >= NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD:
# Exit if there are more than NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD consecutive
# empty responses
return
elif next_token != response["nextForwardToken"]:
num_consecutive_empty_response = 0
else:
# Exit if the value of nextForwardToken is same in subsequent calls
return

next_token = response["nextForwardToken"]
94 changes: 65 additions & 29 deletions tests/providers/amazon/aws/hooks/test_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,82 @@
# under the License.
from __future__ import annotations

import time
from unittest import mock
from unittest.mock import patch

import pytest
from moto import mock_logs

from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook


@mock_logs
class TestAwsLogsHook:
def test_get_conn_returns_a_boto3_connection(self):
hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")
assert hook.get_conn() is not None

def test_get_log_events(self):
# moto.logs does not support proper pagination so we cannot test that yet
# https://github.com/spulec/moto/issues/2259
# ToDo: seems like mock.logs support since https://github.com/spulec/moto/pull/2361

log_group_name = "example-group"
log_stream_name = "example-log-stream"
@pytest.mark.parametrize(
"get_log_events_response, num_skip_events, expected_num_events",
[
# 3 empty responses with different tokens
(
[
{"nextForwardToken": "1", "events": []},
{"nextForwardToken": "2", "events": []},
{"nextForwardToken": "3", "events": []},
],
0,
0,
),
# 2 events on the second response with same token
(
[
{"nextForwardToken": "", "events": []},
{"nextForwardToken": "", "events": [{}, {}]},
],
0,
2,
),
# Different tokens, 2 events on the second response then 3 empty responses
(
[
{"nextForwardToken": "1", "events": []},
{"nextForwardToken": "2", "events": [{}, {}]},
{"nextForwardToken": "3", "events": []},
{"nextForwardToken": "4", "events": []},
{"nextForwardToken": "5", "events": []},
# This one is ignored
{"nextForwardToken": "6", "events": [{}, {}]},
],
0,
2,
),
# 2 events on the second response, then 2 empty responses, then 2 consecutive responses with
# 2 events with the same token
(
[
{"nextForwardToken": "1", "events": []},
{"nextForwardToken": "2", "events": [{}, {}]},
{"nextForwardToken": "3", "events": []},
{"nextForwardToken": "4", "events": []},
{"nextForwardToken": "6", "events": [{}, {}]},
{"nextForwardToken": "6", "events": [{}, {}]},
# This one is ignored
{"nextForwardToken": "6", "events": [{}, {}]},
],
0,
6,
),
],
)
@patch("airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.conn", new_callable=mock.PropertyMock)
def test_get_log_events(self, mock_conn, get_log_events_response, num_skip_events, expected_num_events):
mock_conn().get_log_events.side_effect = get_log_events_response

hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")

# First we create some log events
conn = hook.get_conn()
conn.create_log_group(logGroupName=log_group_name)
conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name)

input_events = [{"timestamp": int(time.time()) * 1000, "message": "Test Message 1"}]

conn.put_log_events(
logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=input_events
events = hook.get_log_events(
log_group="example-group",
log_stream_name="example-log-stream",
skip=num_skip_events,
)

events = hook.get_log_events(log_group=log_group_name, log_stream_name=log_stream_name)

# Iterate through entire generator
events = list(events)
count = len(events)

assert count == 1
assert events[0]["timestamp"] == input_events[0]["timestamp"]
assert events[0]["message"] == input_events[0]["message"]
assert len(events) == expected_num_events
6 changes: 3 additions & 3 deletions tests/providers/amazon/aws/hooks/test_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def test_secondary_training_status_message_status_changed(self):
== expected
)

@mock.patch.object(AwsLogsHook, "get_conn")
@mock.patch.object(AwsLogsHook, "conn")
@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(time, "monotonic")
def test_describe_training_job_with_logs_in_progress(self, mock_time, mock_client, mock_log_client):
Expand Down Expand Up @@ -564,7 +564,7 @@ def test_describe_training_job_with_logs_in_progress(self, mock_time, mock_clien
assert response == (LogState.JOB_COMPLETE, {}, 50)

@pytest.mark.parametrize("log_state", [LogState.JOB_COMPLETE, LogState.COMPLETE])
@mock.patch.object(AwsLogsHook, "get_conn")
@mock.patch.object(AwsLogsHook, "conn")
@mock.patch.object(SageMakerHook, "get_conn")
def test_describe_training_job_with_complete_states(self, mock_client, mock_log_client, log_state):
mock_session = mock.Mock()
Expand All @@ -591,7 +591,7 @@ def test_describe_training_job_with_complete_states(self, mock_client, mock_log_
assert response == (LogState.COMPLETE, {}, 0)

@mock.patch.object(SageMakerHook, "check_training_config")
@mock.patch.object(AwsLogsHook, "get_conn")
@mock.patch.object(AwsLogsHook, "conn")
@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "describe_training_job_with_log")
@mock.patch("time.sleep", return_value=None)
Expand Down

0 comments on commit 7e01c09

Please sign in to comment.