diff --git a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java index 193508ce16f27..d14bb75015f5c 100644 --- a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java +++ b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java @@ -23,7 +23,7 @@ import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.core.fs.CloseableRegistry; -import org.apache.flink.core.testutils.FlinkMatchers; +import org.apache.flink.core.testutils.FlinkAssertions; import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; import org.apache.flink.queryablestate.KvStateID; import org.apache.flink.queryablestate.client.VoidNamespace; @@ -45,7 +45,6 @@ import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.runtime.state.ttl.TtlTimeProvider; import org.apache.flink.util.ExceptionUtils; -import org.apache.flink.util.TestLogger; import org.apache.flink.shaded.netty4.io.netty.bootstrap.ServerBootstrap; import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; @@ -59,11 +58,9 @@ import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioServerSocketChannel; import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder; -import org.hamcrest.core.CombinableMatcher; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -72,7 +69,6 @@ import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.nio.channels.ClosedChannelException; -import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -86,28 +82,24 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicReference; -import static org.hamcrest.CoreMatchers.either; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.fail; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** Tests for {@link Client}. */ -public class ClientTest extends TestLogger { +class ClientTest { private static final Logger LOG = LoggerFactory.getLogger(ClientTest.class); // Thread pool for client bootstrap (shared between tests) private NioEventLoopGroup nioGroup; - @Before - public void setUp() throws Exception { + @BeforeEach + void setUp() { nioGroup = new NioEventLoopGroup(); } - @After - public void tearDown() throws Exception { + @AfterEach + void tearDown() { if (nioGroup != null) { // note: no "quiet period" to not trigger Netty#4357 nioGroup.shutdownGracefully(); @@ -116,7 +108,7 @@ public void tearDown() throws Exception { /** Tests simple queries, of which half succeed and half fail. */ @Test - public void testSimpleRequests() throws Exception { + void testSimpleRequests() throws Exception { AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); MessageSerializer serializer = @@ -156,14 +148,13 @@ public void testSimpleRequests() throws Exception { for (long i = 0L; i < numQueries; i++) { ByteBuf buf = received.take(); - assertNotNull("Receive timed out", buf); + assertThat(buf).withFailMessage("Receive timed out").isNotNull(); Channel ch = channel.get(); - assertNotNull("Channel not active", ch); + assertThat(ch).withFailMessage("Channel not active").isNotNull(); - assertEquals(MessageType.REQUEST, MessageSerializer.deserializeHeader(buf)); + assertThat(MessageType.REQUEST).isEqualTo(MessageSerializer.deserializeHeader(buf)); long requestId = MessageSerializer.getRequestId(buf); - KvStateInternalRequest deserRequest = serializer.deserializeRequest(buf); buf.release(); @@ -188,22 +179,16 @@ public void testSimpleRequests() throws Exception { if (i % 2L == 0L) { KvStateResponse serializedResult = futures.get((int) i).get(); - assertArrayEquals(expected, serializedResult.getContent()); + assertThat(expected).containsExactly(serializedResult.getContent()); } else { - try { - futures.get((int) i).get(); - fail("Did not throw expected Exception"); - } catch (ExecutionException e) { - - if (!(e.getCause() instanceof RuntimeException)) { - fail("Did not throw expected Exception"); - } - // else expected - } + CompletableFuture future = futures.get((int) i); + FlinkAssertions.assertThatFuture(future) + .eventuallyFailsWith(ExecutionException.class) + .satisfies(FlinkAssertions.anyCauseMatches(RuntimeException.class)); } } - assertEquals(numQueries, stats.getNumRequests()); + assertThat(numQueries).isEqualTo(stats.getNumRequests()); long expectedRequests = numQueries / 2L; // Counts can take some time to propagate @@ -212,8 +197,8 @@ public void testSimpleRequests() throws Exception { Thread.sleep(100L); } - assertEquals(expectedRequests, stats.getNumSuccessful()); - assertEquals(expectedRequests, stats.getNumFailed()); + assertThat(expectedRequests).isEqualTo(stats.getNumSuccessful()); + assertThat(expectedRequests).isEqualTo(stats.getNumFailed()); } finally { if (client != null) { Exception exc = null; @@ -230,21 +215,22 @@ public void testSimpleRequests() throws Exception { LOG.error("An exception occurred while shutting down netty.", e); } - Assert.assertTrue( - ExceptionUtils.stringifyException(exc), client.isEventGroupShutdown()); + assertThat(client.isEventGroupShutdown()) + .withFailMessage(ExceptionUtils.stringifyException(exc)) + .isTrue(); } if (serverChannel != null) { serverChannel.close(); } - assertEquals("Channel leak", 0L, stats.getNumConnections()); + assertThat(stats.getNumConnections()).withFailMessage("Channel leak").isZero(); } } /** Tests that a request to an unavailable host is failed with ConnectException. */ @Test - public void testRequestUnavailableHost() throws Exception { + void testRequestUnavailableHost() throws Exception { AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); MessageSerializer serializer = @@ -263,10 +249,8 @@ public void testRequestUnavailableHost() throws Exception { new KvStateInternalRequest(new KvStateID(), new byte[0]); CompletableFuture future = client.sendRequest(serverAddress, request); - assertThat( - future, - FlinkMatchers.futureWillCompleteExceptionally( - ConnectException.class, Duration.ofHours(1))); + assertThat(future).isNotNull(); + assertThatThrownBy(future::get).hasRootCauseInstanceOf(ConnectException.class); } finally { if (client != null) { try { @@ -274,16 +258,16 @@ public void testRequestUnavailableHost() throws Exception { } catch (Exception e) { e.printStackTrace(); } - Assert.assertTrue(client.isEventGroupShutdown()); + assertThat(client.isEventGroupShutdown()).isTrue(); } - assertEquals("Channel leak", 0L, stats.getNumConnections()); + assertThat(stats.getNumConnections()).withFailMessage("Channel leak").isZero(); } } /** Multiple threads concurrently fire queries. */ @Test - public void testConcurrentQueries() throws Exception { + void testConcurrentQueries() throws Exception { AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); final MessageSerializer serializer = @@ -337,7 +321,7 @@ public void testConcurrentQueries() throws Exception { List> results = future.get(); for (CompletableFuture result : results) { KvStateResponse actual = result.get(); - assertArrayEquals(serializedResult, actual.getContent()); + assertThat(serializedResult).containsExactly(actual.getContent()); } } @@ -348,8 +332,8 @@ public void testConcurrentQueries() throws Exception { Thread.sleep(100L); } - assertEquals(totalQueries, stats.getNumRequests()); - assertEquals(totalQueries, stats.getNumSuccessful()); + assertThat(totalQueries).isEqualTo(stats.getNumRequests()); + assertThat(totalQueries).isEqualTo(stats.getNumSuccessful()); } finally { if (executor != null) { executor.shutdown(); @@ -365,10 +349,10 @@ public void testConcurrentQueries() throws Exception { } catch (Exception e) { e.printStackTrace(); } - Assert.assertTrue(client.isEventGroupShutdown()); + assertThat(client.isEventGroupShutdown()).isTrue(); } - assertEquals("Channel leak", 0L, stats.getNumConnections()); + assertThat(stats.getNumConnections()).withFailMessage("Channel leak").isZero(); } } @@ -377,7 +361,7 @@ public void testConcurrentQueries() throws Exception { * connections. */ @Test - public void testFailureClosesChannel() throws Exception { + void testFailureClosesChannel() throws Exception { AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); final MessageSerializer serializer = @@ -400,7 +384,7 @@ public void testFailureClosesChannel() throws Exception { InetSocketAddress serverAddress = getKvStateServerAddress(serverChannel); // Requests - List> futures = new ArrayList<>(); + List> futures = new ArrayList<>(); KvStateInternalRequest request = new KvStateInternalRequest(new KvStateID(), new byte[0]); @@ -408,17 +392,17 @@ public void testFailureClosesChannel() throws Exception { futures.add(client.sendRequest(serverAddress, request)); ByteBuf buf = received.take(); - assertNotNull("Receive timed out", buf); + assertThat(buf).withFailMessage("Receive timed out").isNotNull(); buf.release(); buf = received.take(); - assertNotNull("Receive timed out", buf); + assertThat(buf).withFailMessage("Receive timed out").isNotNull(); buf.release(); - assertEquals(1L, stats.getNumConnections()); + assertThat(stats.getNumConnections()).isEqualTo(1L); Channel ch = channel.get(); - assertNotNull("Channel not active", ch); + assertThat(ch).withFailMessage("Channel not active").isNotNull(); // Respond with failure ch.writeAndFlush( @@ -426,38 +410,26 @@ public void testFailureClosesChannel() throws Exception { serverChannel.alloc(), new RuntimeException("Expected test server failure"))); - try { - futures.remove(0).get(); - fail("Did not throw expected server failure"); - } catch (ExecutionException e) { - - if (!(e.getCause() instanceof RuntimeException)) { - fail("Did not throw expected Exception"); - } - // Expected - } - - try { - futures.remove(0).get(); - fail("Did not throw expected server failure"); - } catch (ExecutionException e) { + CompletableFuture removedFuture = futures.remove(0); + FlinkAssertions.assertThatFuture(removedFuture) + .eventuallyFailsWith(ExecutionException.class) + .satisfies(FlinkAssertions.anyCauseMatches(RuntimeException.class)); - if (!(e.getCause() instanceof RuntimeException)) { - fail("Did not throw expected Exception"); - } - // Expected - } + removedFuture = futures.remove(0); + FlinkAssertions.assertThatFuture(removedFuture) + .eventuallyFailsWith(ExecutionException.class) + .satisfies(FlinkAssertions.anyCauseMatches(RuntimeException.class)); - assertEquals(0L, stats.getNumConnections()); + assertThat(stats.getNumConnections()).isZero(); // Counts can take some time to propagate while (stats.getNumSuccessful() != 0L || stats.getNumFailed() != 2L) { Thread.sleep(100L); } - assertEquals(2L, stats.getNumRequests()); - assertEquals(0L, stats.getNumSuccessful()); - assertEquals(2L, stats.getNumFailed()); + assertThat(stats.getNumRequests()).isEqualTo(2L); + assertThat(stats.getNumSuccessful()).isZero(); + assertThat(stats.getNumFailed()).isEqualTo(2L); } finally { if (client != null) { try { @@ -465,14 +437,14 @@ public void testFailureClosesChannel() throws Exception { } catch (Exception e) { e.printStackTrace(); } - Assert.assertTrue(client.isEventGroupShutdown()); + assertThat(client.isEventGroupShutdown()).isTrue(); } if (serverChannel != null) { serverChannel.close(); } - assertEquals("Channel leak", 0L, stats.getNumConnections()); + assertThat(stats.getNumConnections()).withFailMessage("Channel leak").isZero(); } } @@ -481,7 +453,7 @@ public void testFailureClosesChannel() throws Exception { * connections. */ @Test - public void testServerClosesChannel() throws Exception { + void testServerClosesChannel() throws Exception { AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); final MessageSerializer serializer = @@ -506,34 +478,28 @@ public void testServerClosesChannel() throws Exception { // Requests KvStateInternalRequest request = new KvStateInternalRequest(new KvStateID(), new byte[0]); - Future future = client.sendRequest(serverAddress, request); + CompletableFuture future = client.sendRequest(serverAddress, request); received.take(); - assertEquals(1, stats.getNumConnections()); + assertThat(stats.getNumConnections()).isEqualTo(1); channel.get().close().await(); - try { - future.get(); - fail("Did not throw expected server failure"); - } catch (ExecutionException e) { - if (!(e.getCause() instanceof ClosedChannelException)) { - fail("Did not throw expected Exception"); - } - // Expected - } + FlinkAssertions.assertThatFuture(future) + .eventuallyFailsWith(ExecutionException.class) + .satisfies(FlinkAssertions.anyCauseMatches(ClosedChannelException.class)); - assertEquals(0L, stats.getNumConnections()); + assertThat(stats.getNumConnections()).isZero(); // Counts can take some time to propagate while (stats.getNumSuccessful() != 0L || stats.getNumFailed() != 1L) { Thread.sleep(100L); } - assertEquals(1L, stats.getNumRequests()); - assertEquals(0L, stats.getNumSuccessful()); - assertEquals(1L, stats.getNumFailed()); + assertThat(stats.getNumRequests()).isEqualTo(1L); + assertThat(stats.getNumSuccessful()).isZero(); + assertThat(stats.getNumFailed()).isEqualTo(1L); } finally { if (client != null) { try { @@ -541,14 +507,14 @@ public void testServerClosesChannel() throws Exception { } catch (Exception e) { e.printStackTrace(); } - Assert.assertTrue(client.isEventGroupShutdown()); + assertThat(client.isEventGroupShutdown()).isTrue(); } if (serverChannel != null) { serverChannel.close(); } - assertEquals("Channel leak", 0L, stats.getNumConnections()); + assertThat(stats.getNumConnections()).withFailMessage("Channel leak").isZero(); } } @@ -557,7 +523,7 @@ public void testServerClosesChannel() throws Exception { * this point, the client is shut down and its verified that all ongoing requests are failed. */ @Test - public void testClientServerIntegration() throws Throwable { + void testClientServerIntegration() throws Throwable { // Config final int numServers = 2; final int numServerEventLoopThreads = 2; @@ -699,7 +665,7 @@ public void testClientServerIntegration() throws Throwable { int value = KvStateSerializer.deserializeValue( buf, IntSerializer.INSTANCE); - assertEquals(201L + targetServer, value); + assertThat(value).isEqualTo(201L + targetServer); } } }; @@ -721,29 +687,33 @@ public void testClientServerIntegration() throws Throwable { } catch (Exception e) { e.printStackTrace(); } - Assert.assertTrue(client.isEventGroupShutdown()); - - final CombinableMatcher exceptionMatcher = - either(FlinkMatchers.containsCause(ClosedChannelException.class)) - .or(FlinkMatchers.containsCause(IllegalStateException.class)); + assertThat(client.isEventGroupShutdown()).isTrue(); for (Future future : taskFutures) { try { future.get(); - fail("Did not throw expected Exception after shut down"); - } catch (ExecutionException t) { - assertThat(t, exceptionMatcher); + } catch (Throwable throwable) { + FlinkAssertions.assertThatChainOfCauses(throwable) + .anySatisfy( + cause -> + assertThat(cause) + .isInstanceOfAny( + ClosedChannelException.class, + IllegalStateException.class)); } } - assertEquals("Connection leak (client)", 0L, clientStats.getNumConnections()); + assertThat(clientStats.getNumConnections()) + .withFailMessage("Connection leak (client)") + .isZero(); for (int i = 0; i < numServers; i++) { boolean success = false; int numRetries = 0; while (!success) { try { - assertEquals( - "Connection leak (server)", 0L, serverStats[i].getNumConnections()); + assertThat(serverStats[i].getNumConnections()) + .withFailMessage("Connection leak (server)") + .isZero(); success = true; } catch (Throwable t) { if (numRetries < 10) { @@ -763,7 +733,7 @@ public void testClientServerIntegration() throws Throwable { } catch (Exception e) { e.printStackTrace(); } - Assert.assertTrue(client.isEventGroupShutdown()); + assertThat(client.isEventGroupShutdown()).isTrue(); } for (int i = 0; i < numServers; i++) { @@ -845,7 +815,7 @@ private RespondingChannelHandler( @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { ByteBuf buf = (ByteBuf) msg; - assertEquals(MessageType.REQUEST, MessageSerializer.deserializeHeader(buf)); + assertThat(MessageSerializer.deserializeHeader(buf)).isEqualTo(MessageType.REQUEST); long requestId = MessageSerializer.getRequestId(buf); KvStateInternalRequest request = serializer.deserializeRequest(buf);