[FLINK-11250][runtime] Added init method to RecordWriter for initializing resources (OutputFlusher) outside of the constructor (apache#17187)

* [refactor][streaming] Ability to change bufferTimeout for StreamEdge in StreamConfigChainer

* [FLINK-11250][streaming] Correctly clean up the stream task in every place it is used
akalash authored Sep 17, 2021
1 parent 164a59a commit 3b6b522
Showing 4 changed files with 136 additions and 73 deletions.
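
For context, the core idea behind FLINK-11250 is to stop starting the output-flushing thread inside the writer's constructor and to start it from a separate init step instead, so that a task which fails during setup can still shut the thread down through its normal cleanup path. Below is a minimal, self-contained sketch of that pattern; the class, method names, and threading details are illustrative stand-ins, not Flink's actual RecordWriter API.

import java.util.concurrent.atomic.AtomicBoolean;

/** Illustrative sketch only; not Flink's actual RecordWriter API. */
class SketchRecordWriter implements AutoCloseable {

    private final long flushTimeoutMillis;
    private final AtomicBoolean running = new AtomicBoolean(false);
    private Thread outputFlusher;

    SketchRecordWriter(long flushTimeoutMillis) {
        // The constructor only stores configuration; no thread is started here, so a
        // writer that never gets fully wired into a task cannot leak a flusher thread.
        this.flushTimeoutMillis = flushTimeoutMillis;
    }

    /** Started explicitly by the owner once it can guarantee a matching close(). */
    void init() {
        if (flushTimeoutMillis > 0 && running.compareAndSet(false, true)) {
            outputFlusher =
                    new Thread(
                            () -> {
                                while (running.get() && !Thread.currentThread().isInterrupted()) {
                                    try {
                                        Thread.sleep(flushTimeoutMillis);
                                    } catch (InterruptedException e) {
                                        return;
                                    }
                                    // flush buffered records here
                                }
                            },
                            "OutputFlusher (sketch)");
            outputFlusher.setDaemon(true);
            outputFlusher.start();
        }
    }

    @Override
    public void close() throws InterruptedException {
        running.set(false);
        if (outputFlusher != null) {
            outputFlusher.interrupt();
            outputFlusher.join();
        }
    }
}
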
@@ -751,18 +751,6 @@ private void doRun() {
// by the time we switched to running.
this.invokable = invokable;

// switch to the INITIALIZING state, if that fails, we have been canceled/failed in the
// meantime
if (!transitionState(ExecutionState.DEPLOYING, ExecutionState.INITIALIZING)) {
throw new CancelTaskException();
}

taskManagerActions.updateTaskExecutionState(
new TaskExecutionState(executionId, ExecutionState.INITIALIZING));

// make sure the user code classloader is accessible thread-locally
executingThread.setContextClassLoader(userCodeClassLoader.asClassLoader());

restoreAndInvoke(invokable);

// make sure, we enter the catch block if the task leaves the invoke() method due
@@ -924,6 +912,18 @@ else if (transitionState(current, ExecutionState.FAILED, t)) {

private void restoreAndInvoke(TaskInvokable finalInvokable) throws Exception {
try {
// switch to the INITIALIZING state, if that fails, we have been canceled/failed in the
// meantime
if (!transitionState(ExecutionState.DEPLOYING, ExecutionState.INITIALIZING)) {
throw new CancelTaskException();
}

taskManagerActions.updateTaskExecutionState(
new TaskExecutionState(executionId, ExecutionState.INITIALIZING));

// make sure the user code classloader is accessible thread-locally
executingThread.setContextClassLoader(userCodeClassLoader.asClassLoader());

runWithSystemExitMonitoring(finalInvokable::restore);

if (!transitionState(ExecutionState.INITIALIZING, ExecutionState.RUNNING)) {
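
The Task.java hunks above relocate the DEPLOYING to INITIALIZING transition and the taskManagerActions notification from doRun() into the try block of restoreAndInvoke(). The practical effect is that an exception thrown while reporting INITIALIZING (or while restoring or invoking) now flows through the same catch/finally handling that releases the task's resources. A reduced control-flow sketch with simplified stand-in types, not the real Task class:

/** Simplified stand-ins for the real Flink types; illustration only. */
enum SketchState { DEPLOYING, INITIALIZING, RUNNING, FAILED }

class TaskFlowSketch {

    private SketchState state = SketchState.DEPLOYING;

    void run(Runnable restore, Runnable invoke, Runnable releaseResources) {
        try {
            // The transition and the notification now happen inside the guarded block:
            // if notifying the task manager throws, execution falls through to the
            // catch/finally below instead of bypassing cleanup.
            transitionOrThrow(SketchState.DEPLOYING, SketchState.INITIALIZING);
            notifyTaskManager(SketchState.INITIALIZING); // may throw

            restore.run();

            transitionOrThrow(SketchState.INITIALIZING, SketchState.RUNNING);
            notifyTaskManager(SketchState.RUNNING);

            invoke.run();
        } catch (RuntimeException t) {
            state = SketchState.FAILED;
            notifyTaskManager(SketchState.FAILED);
        } finally {
            // Record writers and other resources are always released, regardless of
            // where in the sequence above the failure happened.
            releaseResources.run();
        }
    }

    private void transitionOrThrow(SketchState expected, SketchState target) {
        if (state != expected) {
            throw new IllegalStateException("canceled or failed in the meantime");
        }
        state = target;
    }

    private void notifyTaskManager(SketchState newState) {
        // stand-in for taskManagerActions.updateTaskExecutionState(...)
    }
}
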
@@ -152,6 +152,26 @@ public void testCleanupWhenAfterInvokeSucceeded() throws Exception {
assertTrue(wasCleanedUp);
}

@Test
public void testCleanupWhenSwitchToInitializationFails() throws Exception {
createTaskBuilder()
.setInvokable(TestInvokableCorrect.class)
.setTaskManagerActions(
new NoOpTaskManagerActions() {
@Override
public void updateTaskExecutionState(
TaskExecutionState taskExecutionState) {
if (taskExecutionState.getExecutionState()
== ExecutionState.INITIALIZING) {
throw new ExpectedTestException();
}
}
})
.build()
.run();
assertTrue(wasCleanedUp);
}

@Test
public void testRegularExecution() throws Exception {
final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions();
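
Both the new test case above and the StreamTaskTest changes further down inject the failure in the same way: a NoOpTaskManagerActions subclass whose updateTaskExecutionState throws once the task reports a chosen state. A generic, self-contained sketch of that injection pattern; the listener interface and exception are simplified stand-ins for Flink's TaskManagerActions and ExpectedTestException.

/** Simplified stand-ins for TaskManagerActions / TaskExecutionState; illustration only. */
interface StateListenerSketch {
    void updateState(String executionState);
}

class FailOnStateSketch implements StateListenerSketch {

    private final String failingState;

    FailOnStateSketch(String failingState) {
        this.failingState = failingState;
    }

    @Override
    public void updateState(String executionState) {
        // Fail exactly when the task reports the state under test, so the test can
        // assert that the failure path still performs the expected cleanup.
        if (failingState.equals(executionState)) {
            throw new RuntimeException("expected test failure at state " + failingState);
        }
    }
}
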
@@ -49,6 +49,7 @@ public class StreamConfigChainer<OWNER> {
private final StreamConfig headConfig;
private final Map<Integer, StreamConfig> chainedConfigs = new HashMap<>();
private final int numberOfNonChainedOutputs;
private int bufferTimeout;

private StreamConfig tailConfig;
private int chainIndex = MAIN_NODE_ID;
@@ -127,26 +128,22 @@ public <IN, OUT> StreamConfigChainer<OWNER> chain(

chainIndex++;

tailConfig.setChainedOutputs(
Collections.singletonList(
new StreamEdge(
new StreamNode(
tailConfig.getChainIndex(),
null,
null,
(StreamOperator<?>) null,
null,
null),
new StreamNode(
chainIndex,
null,
null,
(StreamOperator<?>) null,
null,
null),
0,
StreamEdge streamEdge =
new StreamEdge(
new StreamNode(
tailConfig.getChainIndex(),
null,
null)));
null,
(StreamOperator<?>) null,
null,
null),
new StreamNode(
chainIndex, null, null, (StreamOperator<?>) null, null, null),
0,
null,
null);
streamEdge.setBufferTimeout(bufferTimeout);
tailConfig.setChainedOutputs(Collections.singletonList(streamEdge));
tailConfig = new StreamConfig(new Configuration());
tailConfig.setStreamOperatorFactory(checkNotNull(operatorFactory));
tailConfig.setOperatorID(checkNotNull(operatorID));
@@ -173,7 +170,7 @@ public OWNER finish() {
StreamNode sourceVertex =
new StreamNode(chainIndex, null, null, (StreamOperator<?>) null, null, null);
for (int i = 0; i < numberOfNonChainedOutputs; ++i) {
outEdgesInOrder.add(
StreamEdge streamEdge =
new StreamEdge(
sourceVertex,
new StreamNode(
@@ -185,7 +182,9 @@
null),
0,
new BroadcastPartitioner<>(),
null));
null);
streamEdge.setBufferTimeout(1);
outEdgesInOrder.add(streamEdge);
}

tailConfig.setChainEnd();
@@ -250,4 +249,8 @@ public StreamConfigChainer<OWNER> name(String name) {
tailConfig.setOperatorName(name);
return this;
}

public void setBufferTimeout(int bufferTimeout) {
this.bufferTimeout = bufferTimeout;
}
}
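
The new field and setter let a test choose the buffer timeout that the chainer copies onto every StreamEdge it builds; per the comment added in StreamTaskTest below, an OutputFlusher thread is only started when that timeout is greater than zero. A self-contained sketch of the propagation pattern, using placeholder edge/chainer types rather than Flink's:

import java.util.ArrayList;
import java.util.List;

/** Placeholder types; not Flink's StreamEdge / StreamConfigChainer API. */
class EdgeSketch {
    long bufferTimeoutMillis; // a positive value means the writer will start an output flusher
}

class ChainerSketch {

    private final List<EdgeSketch> edges = new ArrayList<>();
    private long bufferTimeoutMillis; // 0 by default: no flusher thread is started

    void setBufferTimeout(long millis) {
        this.bufferTimeoutMillis = millis;
    }

    ChainerSketch chain() {
        EdgeSketch edge = new EdgeSketch();
        // Mirror of streamEdge.setBufferTimeout(bufferTimeout) in the change above:
        // every edge created by the chainer carries the configured timeout.
        edge.bufferTimeoutMillis = bufferTimeoutMillis;
        edges.add(edge);
        return this;
    }

    List<EdgeSketch> finish() {
        return edges;
    }
}

A test would typically set the timeout before any chain(...) call, as the StreamTaskTest change below does with cfg.setBufferTimeout(1).
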
@@ -49,7 +49,6 @@
import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
import org.apache.flink.runtime.io.network.api.writer.AvailabilityTestResultPartitionWriter;
import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.io.network.partition.consumer.TestInputChannel;
@@ -123,7 +122,6 @@
import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxDefaultAction;
import org.apache.flink.streaming.util.MockStreamConfig;
import org.apache.flink.streaming.util.MockStreamTaskBuilder;
import org.apache.flink.streaming.util.TestSequentialReadingStreamOperator;
import org.apache.flink.util.CloseableIterable;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FatalExitExceptionHandler;
@@ -134,9 +132,11 @@
import org.apache.flink.util.concurrent.FutureUtils;
import org.apache.flink.util.concurrent.TestingUncaughtExceptionHandler;
import org.apache.flink.util.function.BiConsumerWithException;
import org.apache.flink.util.function.FunctionWithException;
import org.apache.flink.util.function.RunnableWithException;
import org.apache.flink.util.function.SupplierWithException;

import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Rule;
@@ -151,7 +151,6 @@
import java.io.Closeable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.StreamCorruptedException;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.ArrayList;
@@ -183,10 +182,13 @@
import static org.apache.flink.configuration.TaskManagerOptions.BUFFER_DEBLOAT_TARGET;
import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.UNKNOWN_TASK_CHECKPOINT_NOTIFICATION_FAILURE;
import static org.apache.flink.runtime.checkpoint.StateObjectCollection.singleton;
import static org.apache.flink.runtime.io.network.api.writer.RecordWriter.DEFAULT_OUTPUT_FLUSH_THREAD_NAME;
import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault;
import static org.apache.flink.streaming.runtime.tasks.mailbox.TaskMailbox.MAX_PRIORITY;
import static org.apache.flink.streaming.util.StreamTaskUtil.waitTaskIsRunning;
import static org.apache.flink.util.Preconditions.checkState;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
@@ -1318,58 +1320,96 @@ public void testThreadInvariants() throws Throwable {
}
}

/**
* This test ensures that {@link RecordWriter} is correctly closed even if we fail to construct
* {@link OperatorChain}, for example because of user class deserialization error.
*/
@Test
public void testRecordWriterClosedOnStreamOperatorFactoryDeserializationError()
public void testRecordWriterClosedOnTransitDeployingStateError() throws Exception {
testRecordWriterClosedOnTransitStateError(ExecutionState.DEPLOYING);
}

@Test
public void testRecordWriterClosedOnTransitInitializingStateError() throws Exception {
testRecordWriterClosedOnTransitStateError(ExecutionState.INITIALIZING);
}

@Test
public void testRecordWriterClosedOnTransitRunningStateError() throws Exception {
testRecordWriterClosedOnTransitStateError(ExecutionState.RUNNING);
}

private void testRecordWriterClosedOnTransitStateError(ExecutionState executionState)
throws Exception {
// Throw the exception when the task's state is updated to the expected one.
NoOpTaskManagerActions taskManagerActions =
new NoOpTaskManagerActions() {
@Override
public void updateTaskExecutionState(TaskExecutionState taskExecutionState) {
if (taskExecutionState.getExecutionState() == executionState) {
throw new ExpectedTestException();
}
}
};

testRecordWriterClosedOnError(
env ->
taskBuilderWithConfiguredRecordWriter(env)
.setTaskManagerActions(taskManagerActions)
.build());
}

private void testRecordWriterClosedOnError(
FunctionWithException<NettyShuffleEnvironment, Task, Exception> taskProvider)
throws Exception {
try (NettyShuffleEnvironment shuffleEnvironment =
new NettyShuffleEnvironmentBuilder().build()) {
Task task = taskProvider.apply(shuffleEnvironment);

task.startTaskThread();
task.getExecutingThread().join();

assertEquals(ExecutionState.FAILED, task.getExecutionState());
for (Thread thread : Thread.getAllStackTraces().keySet()) {
assertThat(
thread.getName(),
CoreMatchers.is(not(containsString(DEFAULT_OUTPUT_FLUSH_THREAD_NAME))));
}
}
}

private TestTaskBuilder taskBuilderWithConfiguredRecordWriter(
NettyShuffleEnvironment shuffleEnvironment) {
Configuration taskConfiguration = new Configuration();
outputEdgeConfiguration(taskConfiguration);

ResultPartitionDeploymentDescriptor descriptor =
new ResultPartitionDeploymentDescriptor(
PartitionDescriptorBuilder.newBuilder().build(),
NettyShuffleDescriptorBuilder.newBuilder().buildLocal(),
1,
false);
return new TestTaskBuilder(shuffleEnvironment)
.setInvokable(NoOpStreamTask.class)
.setTaskConfig(taskConfiguration)
.setResultPartitions(singletonList(descriptor));
}

/**
* Make sure that there is some output edge in the config so that some RecordWriter is created.
*/
private void outputEdgeConfiguration(Configuration taskConfiguration) {
StreamConfig streamConfig = new StreamConfig(taskConfiguration);
streamConfig.setStreamOperatorFactory(new UnusedOperatorFactory());

// Make sure that there is some output edge in the config so that some RecordWriter is
// created
StreamConfigChainer cfg =
new StreamConfigChainer(new OperatorID(42, 42), streamConfig, this, 1);
// The OutputFlusher thread is started only if the buffer timeout is greater than 0 (the
// default value is 0).
cfg.setBufferTimeout(1);
cfg.chain(
new OperatorID(44, 44),
new UnusedOperatorFactory(),
StringSerializer.INSTANCE,
StringSerializer.INSTANCE,
false);
cfg.finish();

// Overwrite the serialized bytes to some garbage to induce deserialization exception
taskConfiguration.setBytes(StreamConfig.SERIALIZEDUDF, new byte[42]);

try (MockEnvironment mockEnvironment =
new MockEnvironmentBuilder().setTaskConfiguration(taskConfiguration).build()) {

mockEnvironment.addOutput(new ArrayList<>());
StreamTask<String, TestSequentialReadingStreamOperator> streamTask =
new NoOpStreamTask<>(mockEnvironment);

try {
streamTask.invoke();
fail("Should have failed with an exception!");
} catch (Exception ex) {
if (!ExceptionUtils.findThrowable(ex, StreamCorruptedException.class).isPresent()) {
throw ex;
}
}
}

assertTrue(
RecordWriter.DEFAULT_OUTPUT_FLUSH_THREAD_NAME + " thread is still running",
Thread.getAllStackTraces().keySet().stream()
.noneMatch(
thread ->
thread.getName()
.startsWith(
RecordWriter
.DEFAULT_OUTPUT_FLUSH_THREAD_NAME)));
}

@Test
