Skip to content

Commit

Permalink
[hotfix] migrate task deployment related tests to junit5 and assertj.
Browse files Browse the repository at this point in the history
  • Loading branch information
huwh authored and wanglijie95 committed Jul 1, 2023
1 parent c174086 commit a68c047
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
Expand All @@ -49,40 +47,38 @@
import org.apache.flink.runtime.scheduler.SchedulerBase;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorResource;
import org.apache.flink.util.CompressedSerializedValue;
import org.apache.flink.util.TestLogger;
import org.apache.flink.testutils.executor.TestExecutorExtension;

import org.junit.ClassRule;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;

import static org.apache.flink.configuration.JobManagerOptions.HybridPartitionDataConsumeConstraint.ONLY_FINISHED_PRODUCERS;
import static org.junit.Assert.assertEquals;
import static org.apache.flink.runtime.deployment.TaskDeploymentDescriptorTestUtils.deserializeShuffleDescriptors;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/** Tests for {@link TaskDeploymentDescriptorFactory}. */
public class TaskDeploymentDescriptorFactoryTest extends TestLogger {
@ClassRule
public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
TestingUtils.defaultExecutorResource();
class TaskDeploymentDescriptorFactoryTest {
@RegisterExtension
private static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
TestingUtils.defaultExecutorExtension();

private static final int PARALLELISM = 4;

@Test
public void testCacheShuffleDescriptorAsNonOffloaded() throws Exception {
void testCacheShuffleDescriptorAsNonOffloaded() throws Exception {
testCacheShuffleDescriptor(new TestingBlobWriter(Integer.MAX_VALUE));
}

@Test
public void testCacheShuffleDescriptorAsOffloaded() throws Exception {
void testCacheShuffleDescriptorAsOffloaded() throws Exception {
testCacheShuffleDescriptor(new TestingBlobWriter(0));
}

Expand All @@ -106,19 +102,18 @@ private void testCacheShuffleDescriptor(TestingBlobWriter blobWriter) throws Exc
deserializeShuffleDescriptors(maybeOffloaded, jobId, blobWriter);

// Check if the ShuffleDescriptors are cached correctly
assertEquals(ev21.getConsumedPartitionGroup(0).size(), cachedShuffleDescriptors.length);
assertThat(ev21.getConsumedPartitionGroup(0)).hasSize(cachedShuffleDescriptors.length);

int idx = 0;
for (IntermediateResultPartitionID consumedPartitionId :
ev21.getConsumedPartitionGroup(0)) {
assertEquals(
consumedPartitionId,
cachedShuffleDescriptors[idx++].getResultPartitionID().getPartitionId());
assertThat(cachedShuffleDescriptors[idx++].getResultPartitionID().getPartitionId())
.isEqualTo(consumedPartitionId);
}
}

@Test
public void testHybridVertexFinish() throws Exception {
void testHybridVertexFinish() throws Exception {
final Tuple2<ExecutionJobVertex, ExecutionJobVertex> executionJobVertices =
buildExecutionGraph();
final ExecutionJobVertex ejv1 = executionJobVertices.f0;
Expand All @@ -139,7 +134,7 @@ public void testHybridVertexFinish() throws Exception {
consumedResult
.getCachedShuffleDescriptors(ev22.getConsumedPartitionGroup(0))
.getAllSerializedShuffleDescriptors();
assertEquals(maybeOffloaded.size(), 2);
assertThat(maybeOffloaded).hasSize(2);

final ExecutionVertex ev13 = ejv1.getTaskVertices()[2];
ev13.finishPartitionsIfNeeded();
Expand All @@ -150,11 +145,11 @@ public void testHybridVertexFinish() throws Exception {
consumedResult
.getCachedShuffleDescriptors(ev23.getConsumedPartitionGroup(0))
.getAllSerializedShuffleDescriptors();
assertEquals(maybeOffloaded.size(), 3);
assertThat(maybeOffloaded).hasSize(3);
}

@Test(expected = IllegalStateException.class)
public void testGetOffloadedShuffleDescriptorBeforeLoading() throws Exception {
@Test
void testGetOffloadedShuffleDescriptorBeforeLoading() throws Exception {
final TestingBlobWriter blobWriter = new TestingBlobWriter(0);

final JobID jobId = new JobID();
Expand All @@ -166,7 +161,8 @@ public void testGetOffloadedShuffleDescriptorBeforeLoading() throws Exception {
final TaskDeploymentDescriptor tdd = createTaskDeploymentDescriptor(ev21);

// Exception should be thrown when trying to get offloaded shuffle descriptors
tdd.getInputGates().get(0).getShuffleDescriptors();
assertThatThrownBy(() -> tdd.getInputGates().get(0).getShuffleDescriptors())
.isInstanceOf(IllegalStateException.class);
}

private Tuple2<ExecutionJobVertex, ExecutionJobVertex> setupExecutionGraphAndGetVertices(
Expand Down Expand Up @@ -260,43 +256,4 @@ private static TaskDeploymentDescriptor createTaskDeploymentDescriptor(Execution
null,
Collections.emptyList());
}

public static ShuffleDescriptor[] deserializeShuffleDescriptors(
List<MaybeOffloaded<ShuffleDescriptorAndIndex[]>> maybeOffloaded,
JobID jobId,
TestingBlobWriter blobWriter)
throws IOException, ClassNotFoundException {
Map<Integer, ShuffleDescriptor> shuffleDescriptorsMap = new HashMap<>();
int maxIndex = 0;
for (MaybeOffloaded<ShuffleDescriptorAndIndex[]> sd : maybeOffloaded) {
ShuffleDescriptorAndIndex[] shuffleDescriptorAndIndices;
if (sd instanceof NonOffloaded) {
shuffleDescriptorAndIndices =
((NonOffloaded<ShuffleDescriptorAndIndex[]>) sd)
.serializedValue.deserializeValue(
ClassLoader.getSystemClassLoader());

} else {
final CompressedSerializedValue<ShuffleDescriptorAndIndex[]>
compressedSerializedValue =
CompressedSerializedValue.fromBytes(
blobWriter.getBlob(
jobId,
((Offloaded<ShuffleDescriptorAndIndex[]>) sd)
.serializedValueKey));
shuffleDescriptorAndIndices =
compressedSerializedValue.deserializeValue(
ClassLoader.getSystemClassLoader());
}
for (ShuffleDescriptorAndIndex shuffleDescriptorAndIndex :
shuffleDescriptorAndIndices) {
int index = shuffleDescriptorAndIndex.getIndex();
maxIndex = Math.max(maxIndex, shuffleDescriptorAndIndex.getIndex());
shuffleDescriptorsMap.put(index, shuffleDescriptorAndIndex.getShuffleDescriptor());
}
}
ShuffleDescriptor[] shuffleDescriptors = new ShuffleDescriptor[maxIndex + 1];
shuffleDescriptorsMap.forEach((key, value) -> shuffleDescriptors[key] = value);
return shuffleDescriptors;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.operators.BatchTask;
import org.apache.flink.util.SerializedValue;
import org.apache.flink.util.TestLogger;

import org.junit.Test;
import org.junit.jupiter.api.Test;

import javax.annotation.Nonnull;

Expand All @@ -45,14 +44,11 @@
import java.util.List;

import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createExecutionAttemptId;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/** Tests for the {@link TaskDeploymentDescriptor}. */
public class TaskDeploymentDescriptorTest extends TestLogger {
class TaskDeploymentDescriptorTest {

private static final JobID jobID = new JobID();
private static final JobVertexID vertexID = new JobVertexID();
Expand Down Expand Up @@ -96,10 +92,10 @@ public class TaskDeploymentDescriptorTest extends TestLogger {
invokableClass.getName(),
taskConfiguration));

public TaskDeploymentDescriptorTest() throws IOException {}
TaskDeploymentDescriptorTest() throws IOException {}

@Test
public void testSerialization() throws Exception {
void testSerialization() throws Exception {
final TaskDeploymentDescriptor orig =
createTaskDeploymentDescriptor(
new TaskDeploymentDescriptor.NonOffloaded<>(serializedJobInformation),
Expand All @@ -108,46 +104,44 @@ public void testSerialization() throws Exception {

final TaskDeploymentDescriptor copy = CommonTestUtils.createCopySerializable(orig);

assertFalse(orig.getSerializedJobInformation() == copy.getSerializedJobInformation());
assertFalse(orig.getSerializedTaskInformation() == copy.getSerializedTaskInformation());
assertFalse(orig.getExecutionAttemptId() == copy.getExecutionAttemptId());
assertFalse(orig.getTaskRestore() == copy.getTaskRestore());
assertFalse(orig.getProducedPartitions() == copy.getProducedPartitions());
assertFalse(orig.getInputGates() == copy.getInputGates());

assertEquals(orig.getSerializedJobInformation(), copy.getSerializedJobInformation());
assertEquals(orig.getSerializedTaskInformation(), copy.getSerializedTaskInformation());
assertEquals(orig.getExecutionAttemptId(), copy.getExecutionAttemptId());
assertEquals(orig.getAllocationId(), copy.getAllocationId());
assertEquals(orig.getSubtaskIndex(), copy.getSubtaskIndex());
assertEquals(orig.getAttemptNumber(), copy.getAttemptNumber());
assertEquals(
orig.getTaskRestore().getRestoreCheckpointId(),
copy.getTaskRestore().getRestoreCheckpointId());
assertEquals(
orig.getTaskRestore().getTaskStateSnapshot(),
copy.getTaskRestore().getTaskStateSnapshot());
assertEquals(orig.getProducedPartitions(), copy.getProducedPartitions());
assertEquals(orig.getInputGates(), copy.getInputGates());
assertThat(orig.getSerializedJobInformation())
.isNotSameAs(copy.getSerializedJobInformation());
assertThat(orig.getSerializedTaskInformation())
.isNotSameAs(copy.getSerializedTaskInformation());
assertThat(orig.getExecutionAttemptId()).isNotSameAs(copy.getExecutionAttemptId());
assertThat(orig.getTaskRestore()).isNotSameAs(copy.getTaskRestore());
assertThat(orig.getProducedPartitions()).isNotSameAs(copy.getProducedPartitions());
assertThat(orig.getInputGates()).isNotSameAs(copy.getInputGates());

assertThat(orig.getSerializedJobInformation())
.isEqualTo(copy.getSerializedJobInformation());
assertThat(orig.getSerializedTaskInformation())
.isEqualTo(copy.getSerializedTaskInformation());
assertThat(orig.getExecutionAttemptId()).isEqualTo(copy.getExecutionAttemptId());
assertThat(orig.getAllocationId()).isEqualTo(copy.getAllocationId());
assertThat(orig.getSubtaskIndex()).isEqualTo(copy.getSubtaskIndex());
assertThat(orig.getAttemptNumber()).isEqualTo(copy.getAttemptNumber());
assertThat(orig.getTaskRestore().getRestoreCheckpointId())
.isEqualTo(copy.getTaskRestore().getRestoreCheckpointId());
assertThat(orig.getTaskRestore().getTaskStateSnapshot())
.isEqualTo(copy.getTaskRestore().getTaskStateSnapshot());
assertThat(orig.getProducedPartitions()).isEqualTo(copy.getProducedPartitions());
assertThat(orig.getInputGates()).isEqualTo(copy.getInputGates());
}

@Test
public void testOffLoadedAndNonOffLoadedPayload() {
void testOffLoadedAndNonOffLoadedPayload() {
final TaskDeploymentDescriptor taskDeploymentDescriptor =
createTaskDeploymentDescriptor(
new TaskDeploymentDescriptor.NonOffloaded<>(serializedJobInformation),
new TaskDeploymentDescriptor.Offloaded<>(new PermanentBlobKey()));

SerializedValue<JobInformation> actualSerializedJobInformation =
taskDeploymentDescriptor.getSerializedJobInformation();
assertThat(actualSerializedJobInformation, is(serializedJobInformation));

try {
taskDeploymentDescriptor.getSerializedTaskInformation();
fail("Expected to fail since the task information should be offloaded.");
} catch (IllegalStateException expected) {
// expected
}
assertThat(actualSerializedJobInformation).isSameAs(serializedJobInformation);

assertThatThrownBy(taskDeploymentDescriptor::getSerializedTaskInformation)
.isInstanceOf(IllegalStateException.class);
}

@Nonnull
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/

package org.apache.flink.runtime.deployment;

import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.blob.TestingBlobWriter;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.util.CompressedSerializedValue;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* A collection of utility methods for testing the TaskDeploymentDescriptor and its related classes.
*/
public class TaskDeploymentDescriptorTestUtils {

public static ShuffleDescriptor[] deserializeShuffleDescriptors(
List<MaybeOffloaded<ShuffleDescriptorAndIndex[]>> maybeOffloaded,
JobID jobId,
TestingBlobWriter blobWriter)
throws IOException, ClassNotFoundException {
Map<Integer, ShuffleDescriptor> shuffleDescriptorsMap = new HashMap<>();
int maxIndex = 0;
for (MaybeOffloaded<ShuffleDescriptorAndIndex[]> sd : maybeOffloaded) {
ShuffleDescriptorAndIndex[] shuffleDescriptorAndIndices;
if (sd instanceof NonOffloaded) {
shuffleDescriptorAndIndices =
((NonOffloaded<ShuffleDescriptorAndIndex[]>) sd)
.serializedValue.deserializeValue(
ClassLoader.getSystemClassLoader());

} else {
final CompressedSerializedValue<ShuffleDescriptorAndIndex[]>
compressedSerializedValue =
CompressedSerializedValue.fromBytes(
blobWriter.getBlob(
jobId,
((Offloaded<ShuffleDescriptorAndIndex[]>) sd)
.serializedValueKey));
shuffleDescriptorAndIndices =
compressedSerializedValue.deserializeValue(
ClassLoader.getSystemClassLoader());
}
for (ShuffleDescriptorAndIndex shuffleDescriptorAndIndex :
shuffleDescriptorAndIndices) {
int index = shuffleDescriptorAndIndex.getIndex();
maxIndex = Math.max(maxIndex, shuffleDescriptorAndIndex.getIndex());
shuffleDescriptorsMap.put(index, shuffleDescriptorAndIndex.getShuffleDescriptor());
}
}
ShuffleDescriptor[] shuffleDescriptors = new ShuffleDescriptor[maxIndex + 1];
shuffleDescriptorsMap.forEach((key, value) -> shuffleDescriptors[key] = value);
return shuffleDescriptors;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeoutException;

import static org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors;
import static org.apache.flink.runtime.deployment.TaskDeploymentDescriptorTestUtils.deserializeShuffleDescriptors;
import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishExecutionVertex;
import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex;
import static org.assertj.core.api.Assertions.assertThat;
Expand Down

0 comments on commit a68c047

Please sign in to comment.