From bbbfd1e2f14e62703d2dc5f1cb95631ab10f6d72 Mon Sep 17 00:00:00 2001 From: Edgar Rodriguez Date: Thu, 17 Oct 2024 09:48:54 -0400 Subject: [PATCH] AWS: Fix S3InputStream retry policy (#11335) --- .../apache/iceberg/aws/s3/S3InputStream.java | 26 ++++++++++++++++--- .../aws/s3/TestFlakyS3InputStream.java | 26 +++++++++++++++++++ .../iceberg/aws/s3/TestS3InputStream.java | 12 ++++++--- 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java index 74e602a27378..4af71932e599 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java @@ -26,6 +26,7 @@ import java.net.SocketException; import java.net.SocketTimeoutException; import java.util.Arrays; +import java.util.List; import javax.net.ssl.SSLException; import org.apache.iceberg.exceptions.NotFoundException; import org.apache.iceberg.io.FileIOMetricsContext; @@ -35,6 +36,7 @@ import org.apache.iceberg.metrics.Counter; import org.apache.iceberg.metrics.MetricsContext; import org.apache.iceberg.metrics.MetricsContext.Unit; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; import org.apache.iceberg.relocated.com.google.common.base.Joiner; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; @@ -50,6 +52,9 @@ class S3InputStream extends SeekableInputStream implements RangeReadable { private static final Logger LOG = LoggerFactory.getLogger(S3InputStream.class); + private static final List> RETRYABLE_EXCEPTIONS = + ImmutableList.of(SSLException.class, SocketTimeoutException.class, SocketException.class); + private final StackTraceElement[] createStack; private final S3Client s3; private final S3URI location; @@ -66,10 +71,18 @@ class S3InputStream extends SeekableInputStream implements RangeReadable { private int skipSize = 1024 * 1024; private RetryPolicy retryPolicy = RetryPolicy.builder() - .handle( - ImmutableList.of( - SSLException.class, SocketTimeoutException.class, SocketException.class)) - .onFailure(failure -> openStream(true)) + .handle(RETRYABLE_EXCEPTIONS) + .onRetry( + e -> { + LOG.warn( + "Retrying read from S3, reopening stream (attempt {})", e.getAttemptCount()); + resetForRetry(); + }) + .onFailure( + e -> + LOG.error( + "Failed to read from S3 input stream after exhausting all retries", + e.getException())) .withMaxRetries(3) .build(); @@ -230,6 +243,11 @@ private void openStream(boolean closeQuietly) throws IOException { } } + @VisibleForTesting + void resetForRetry() throws IOException { + openStream(true); + } + private void closeStream(boolean closeQuietly) throws IOException { if (stream != null) { // if we aren't at the end of the stream, and the stream is abortable, then diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java b/aws/src/test/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java index 08d14512cdc7..f98d1a3d4471 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java +++ b/aws/src/test/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java @@ -18,6 +18,7 @@ */ package org.apache.iceberg.aws.s3; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -29,6 +30,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; import javax.net.ssl.SSLException; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -49,10 +51,29 @@ public class TestFlakyS3InputStream extends TestS3InputStream { + private AtomicInteger resetForRetryCounter; + + @BeforeEach + public void setupTest() { + resetForRetryCounter = new AtomicInteger(0); + } + + @Override + S3InputStream newInputStream(S3Client s3Client, S3URI uri) { + return new S3InputStream(s3Client, uri) { + @Override + void resetForRetry() throws IOException { + resetForRetryCounter.incrementAndGet(); + super.resetForRetry(); + } + }; + } + @ParameterizedTest @MethodSource("retryableExceptions") public void testReadWithFlakyStreamRetrySucceed(IOException exception) throws Exception { testRead(flakyStreamClient(new AtomicInteger(3), exception)); + assertThat(resetForRetryCounter.get()).isEqualTo(2); } @ParameterizedTest @@ -61,6 +82,7 @@ public void testReadWithFlakyStreamExhaustedRetries(IOException exception) { assertThatThrownBy(() -> testRead(flakyStreamClient(new AtomicInteger(5), exception))) .isInstanceOf(exception.getClass()) .hasMessage(exception.getMessage()); + assertThat(resetForRetryCounter.get()).isEqualTo(3); } @ParameterizedTest @@ -69,12 +91,14 @@ public void testReadWithFlakyStreamNonRetryableException(IOException exception) assertThatThrownBy(() -> testRead(flakyStreamClient(new AtomicInteger(3), exception))) .isInstanceOf(exception.getClass()) .hasMessage(exception.getMessage()); + assertThat(resetForRetryCounter.get()).isEqualTo(0); } @ParameterizedTest @MethodSource("retryableExceptions") public void testSeekWithFlakyStreamRetrySucceed(IOException exception) throws Exception { testSeek(flakyStreamClient(new AtomicInteger(3), exception)); + assertThat(resetForRetryCounter.get()).isEqualTo(2); } @ParameterizedTest @@ -83,6 +107,7 @@ public void testSeekWithFlakyStreamExhaustedRetries(IOException exception) { assertThatThrownBy(() -> testSeek(flakyStreamClient(new AtomicInteger(5), exception))) .isInstanceOf(exception.getClass()) .hasMessage(exception.getMessage()); + assertThat(resetForRetryCounter.get()).isEqualTo(3); } @ParameterizedTest @@ -91,6 +116,7 @@ public void testSeekWithFlakyStreamNonRetryableException(IOException exception) assertThatThrownBy(() -> testSeek(flakyStreamClient(new AtomicInteger(3), exception))) .isInstanceOf(exception.getClass()) .hasMessage(exception.getMessage()); + assertThat(resetForRetryCounter.get()).isEqualTo(0); } private static Stream retryableExceptions() { diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java b/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java index 0e3f8b2136a6..f5b78eddaaad 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java +++ b/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java @@ -57,6 +57,10 @@ public void testRead() throws Exception { testRead(s3); } + S3InputStream newInputStream(S3Client s3Client, S3URI uri) { + return new S3InputStream(s3Client, uri); + } + protected void testRead(S3Client s3Client) throws Exception { S3URI uri = new S3URI("s3://bucket/path/to/read.dat"); int dataSize = 1024 * 1024 * 10; @@ -64,7 +68,7 @@ protected void testRead(S3Client s3Client) throws Exception { writeS3Data(uri, data); - try (SeekableInputStream in = new S3InputStream(s3Client, uri)) { + try (SeekableInputStream in = newInputStream(s3Client, uri)) { int readSize = 1024; readAndCheck(in, in.getPos(), readSize, data, false); readAndCheck(in, in.getPos(), readSize, data, true); @@ -128,7 +132,7 @@ protected void testRangeRead(S3Client s3Client) throws Exception { writeS3Data(uri, expected); - try (RangeReadable in = new S3InputStream(s3Client, uri)) { + try (RangeReadable in = newInputStream(s3Client, uri)) { // first 1k position = 0; offset = 0; @@ -160,7 +164,7 @@ private void readAndCheckRanges( @Test public void testClose() throws Exception { S3URI uri = new S3URI("s3://bucket/path/to/closed.dat"); - SeekableInputStream closed = new S3InputStream(s3, uri); + SeekableInputStream closed = newInputStream(s3, uri); closed.close(); assertThatThrownBy(() -> closed.seek(0)) .isInstanceOf(IllegalStateException.class) @@ -178,7 +182,7 @@ protected void testSeek(S3Client s3Client) throws Exception { writeS3Data(uri, expected); - try (SeekableInputStream in = new S3InputStream(s3Client, uri)) { + try (SeekableInputStream in = newInputStream(s3Client, uri)) { in.seek(expected.length / 2); byte[] actual = new byte[expected.length / 2]; IOUtil.readFully(in, actual, 0, expected.length / 2);