Skip to content

Commit

Permalink
[FLINK-14709] Allow outputting elements in user defined close method,…
Browse files Browse the repository at this point in the history
… executed by chained driver.
  • Loading branch information
David Moravek authored and aljoscha committed Nov 28, 2019
1 parent 981b054 commit e583a1c
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,11 @@ protected void run() throws Exception {
stubOpen = false;
}

this.output.close();

// close all chained tasks letting them report failure
BatchTask.closeChainedTasks(this.chainedTasks, this);

// close the output collector
this.output.close();
}
catch (Exception ex) {
// close the input, but do not report any exceptions, since we already have another root cause
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,11 @@ public void invoke() throws Exception {
completedSplitsCounter.inc();
} // end for all input splits

// close the collector. if it is a chaining task collector, it will close its chained tasks
this.output.close();

// close all chained tasks letting them report failure
BatchTask.closeChainedTasks(this.chainedTasks, this);

// close the output collector
this.output.close();
}
catch (Exception ex) {
// close the input, but do not report any exceptions, since we already have another root cause
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ public void run() {
}


private static class InputFilePreparator {
public static class InputFilePreparator {
public static void prepareInputFile(MutableObjectIterator<Record> inIt, File inputFile, boolean insertInvalidData)
throws IOException {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,22 @@

package org.apache.flink.runtime.operators.chaining;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.testutils.recordutils.RecordComparatorFactory;
import org.apache.flink.runtime.testutils.recordutils.RecordSerializerFactory;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.operators.DataSourceTask;
import org.apache.flink.runtime.operators.DataSourceTaskTest;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.BatchTask;
import org.apache.flink.runtime.operators.FlatMapDriver;
Expand All @@ -42,10 +48,15 @@
import org.apache.flink.util.Collector;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

public class ChainTaskTest extends TaskTestBase {


@Rule
public TemporaryFolder tempFolder = new TemporaryFolder();

private static final int MEMORY_MANAGER_SIZE = 1024 * 1024 * 3;

private static final int NETWORK_BUFFER_SIZE = 1024;
Expand All @@ -55,7 +66,7 @@ public class ChainTaskTest extends TaskTestBase {
@SuppressWarnings("unchecked")
private final RecordComparatorFactory compFact = new RecordComparatorFactory(new int[]{0}, new Class[]{IntValue.class}, new boolean[] {true});
private final RecordSerializerFactory serFact = RecordSerializerFactory.get();

@Test
public void testMapTask() {
final int keyCnt = 100;
Expand Down Expand Up @@ -98,7 +109,7 @@ public void testMapTask() {
{
registerTask(FlatMapDriver.class, MockMapStub.class);
BatchTask<FlatMapFunction<Record, Record>, Record> testTask = new BatchTask<>(this.mockEnv);

try {
testTask.invoke();
} catch (Exception e) {
Expand Down Expand Up @@ -174,7 +185,68 @@ public void testFailingMapTask() {
Assert.fail(e.getMessage());
}
}


/**
 * Runs a {@link BatchTask} with a chain of {@code numChainedTasks}
 * {@link MockDuplicateLastValueMapFunction}s and checks that elements emitted from the
 * user-defined {@code close()} methods still reach the output.
 *
 * <p>Each chained function collects one additional element when it is closed, so the
 * assertion expects {@code keyCnt * valCnt + numChainedTasks} elements in {@code outList}.
 */
@Test
public void testBatchTaskOutputInCloseMethod() {
	final int numChainedTasks = 10;
	final int keyCnt = 100;
	final int valCnt = 10;
	try {
		initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
		addInput(new UniformRecordGenerator(keyCnt, valCnt, false), 0);
		addOutput(outList);
		registerTask(FlatMapDriver.class, MockMapStub.class);
		for (int i = 0; i < numChainedTasks; i++) {
			final TaskConfig taskConfig = new TaskConfig(new Configuration());
			taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
			taskConfig.setOutputSerializer(serFact);
			taskConfig.setStubWrapper(
				new UserCodeClassWrapper<>(MockDuplicateLastValueMapFunction.class));
			getTaskConfig().addChainedTask(
				ChainedFlatMapDriver.class, taskConfig, "chained-" + i);
		}
		final BatchTask<FlatMapFunction<Record, Record>, Record> testTask =
			new BatchTask<>(mockEnv);
		testTask.invoke();
		// One extra element per chained function, emitted from its close() method.
		Assert.assertEquals(keyCnt * valCnt + numChainedTasks, outList.size());
	}
	catch (Exception e) {
		// Fail the test but keep the original exception (and its stack trace) as the cause,
		// instead of printing it to stderr and reporting only its possibly-null message.
		throw new AssertionError(e.getMessage(), e);
	}
}

/**
 * Runs a {@link DataSourceTask} reading from a temporary input file, followed by a chain of
 * {@code numChainedTasks} {@link MockDuplicateLastValueMapFunction}s, and checks that
 * elements emitted from the user-defined {@code close()} methods still reach the output.
 *
 * <p>Each chained function collects one additional element when it is closed, so the
 * assertion expects {@code keyCnt * valCnt + numChainedTasks} elements in {@code outList}.
 *
 * @throws IOException if the temporary input file cannot be prepared
 */
@Test
public void testDataSourceTaskOutputInCloseMethod() throws IOException {
	final int numChainedTasks = 10;
	final int keyCnt = 100;
	final int valCnt = 10;
	// Unique file name inside the JUnit-managed temporary folder.
	final File tempTestFile = new File(tempFolder.getRoot(), UUID.randomUUID().toString());
	DataSourceTaskTest.InputFilePreparator.prepareInputFile(
		new UniformRecordGenerator(keyCnt, valCnt, false), tempTestFile, true);
	initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
	addOutput(outList);
	final DataSourceTask<Record> testTask = new DataSourceTask<>(mockEnv);
	registerFileInputTask(
		testTask, DataSourceTaskTest.MockInputFormat.class, tempTestFile.toURI().toString(), "\n");
	for (int i = 0; i < numChainedTasks; i++) {
		final TaskConfig taskConfig = new TaskConfig(new Configuration());
		taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
		taskConfig.setOutputSerializer(serFact);
		taskConfig.setStubWrapper(
			new UserCodeClassWrapper<>(ChainTaskTest.MockDuplicateLastValueMapFunction.class));
		getTaskConfig().addChainedTask(
			ChainedFlatMapDriver.class, taskConfig, "chained-" + i);
	}
	try {
		testTask.invoke();
		// One extra element per chained function, emitted from its close() method.
		Assert.assertEquals(keyCnt * valCnt + numChainedTasks, outList.size());
	} catch (Exception e) {
		// Fail the test but keep the original exception (and its stack trace) as the cause,
		// so the failure report shows why invoke() failed.
		throw new AssertionError("Invoke method caused exception.", e);
	}
}

public static final class MockFailingCombineStub implements
GroupReduceFunction<Record, Record>,
GroupCombineFunction<Record, Record> {
Expand All @@ -198,4 +270,34 @@ public void combine(Iterable<Record> values, Collector<Record> out) throws Excep
reduce(values, out);
}
}

/**
 * FlatMap function that forwards every element unchanged and, when closed, emits the last
 * forwarded element once more through the collector it last received.
 *
 * <p>If {@link #close()} is called before any element was processed, nothing is emitted.
 * Calling {@link #flatMap(Object, Collector)} after {@link #close()} fails with an
 * {@link IllegalStateException}.
 *
 * @param <T> Input and output type.
 */
public static class MockDuplicateLastValueMapFunction<T> extends RichFlatMapFunction<T, T> {

	/** Set once {@link #close()} has run; used to detect elements arriving after close. */
	private boolean closed = false;

	/** Last element passed to {@link #flatMap(Object, Collector)}, or null if none yet. */
	private transient T value;

	/** Collector received with the last element, or null if no element was processed. */
	private transient Collector<T> out;

	@Override
	public void flatMap(T value, Collector<T> out) throws Exception {
		if (closed) {
			// Receiving an element after close() means the chained tasks
			// were not closed in the proper order.
			throw new IllegalStateException("Task is already closed.");
		}
		this.value = value;
		this.out = out;
		out.collect(value);
	}

	@Override
	public void close() throws Exception {
		closed = true;
		// No collector was ever handed over if flatMap() was never called (e.g. empty input);
		// in that case there is no last element to duplicate.
		if (out != null) {
			out.collect(value);
		}
	}
}
}

0 comments on commit e583a1c

Please sign in to comment.