Skip to content

Commit

Permalink
[FLINK-2662] [dataSet] [optimizer] Fix merging of unions with multipl…
Browse files Browse the repository at this point in the history
…e outputs.

Translate union with N outputs into N unions with single output.

This closes apache#2508.
  • Loading branch information
fhueske committed Sep 20, 2016
1 parent 5c02988 commit 303f6fe
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
public class OperatorTranslation {

/** The already translated operations */
private Map<DataSet<?>, Operator<?>> translated = new HashMap<DataSet<?>, Operator<?>>();
private Map<DataSet<?>, Operator<?>> translated = new HashMap<>();


public Plan translateToPlan(List<DataSink<?>> sinks, String jobName) {
List<GenericDataSinkBase<?>> planSinks = new ArrayList<GenericDataSinkBase<?>>();
List<GenericDataSinkBase<?>> planSinks = new ArrayList<>();

for (DataSink<?> sink : sinks) {
planSinks.add(translate(sink));
Expand Down Expand Up @@ -74,11 +74,18 @@ private <T> Operator<T> translate(DataSet<T> dataSet) {
}

// check if we have already translated that data set (operation or source)
Operator<?> previous = (Operator<?>) this.translated.get(dataSet);
Operator<?> previous = this.translated.get(dataSet);
if (previous != null) {
@SuppressWarnings("unchecked")
Operator<T> typedPrevious = (Operator<T>) previous;
return typedPrevious;

// Union operators may only have a single output.
// We ensure this by not reusing previously created union operators.
// The optimizer will merge subsequent binary unions into one n-ary union.
if (!(dataSet instanceof UnionOperator)) {
// all other operators are reused.
@SuppressWarnings("unchecked")
Operator<T> typedPrevious = (Operator<T>) previous;
return typedPrevious;
}
}

Operator<T> dataFlowOp;
Expand Down Expand Up @@ -190,7 +197,7 @@ private <T> BulkIterationBase<T> translateBulkIteration(BulkIterationResultSet<?
BulkIterationResultSet<T> iterationEnd = (BulkIterationResultSet<T>) untypedIterationEnd;

BulkIterationBase<T> iterationOperator =
new BulkIterationBase<T>(new UnaryOperatorInformation<T, T>(iterationEnd.getType(), iterationEnd.getType()), "Bulk Iteration");
new BulkIterationBase<>(new UnaryOperatorInformation<>(iterationEnd.getType(), iterationEnd.getType()), "Bulk Iteration");
IterativeDataSet<T> iterationHead = iterationEnd.getIterationHead();

translated.put(iterationHead, iterationOperator.getPartialSolution());
Expand All @@ -216,7 +223,7 @@ private <D, W> DeltaIterationBase<D, W> translateDeltaIteration(DeltaIterationRe

String name = iterationHead.getName() == null ? "Unnamed Delta Iteration" : iterationHead.getName();

DeltaIterationBase<D, W> iterationOperator = new DeltaIterationBase<D, W>(new BinaryOperatorInformation<D, W, D>(iterationEnd.getType(), iterationEnd.getWorksetType(), iterationEnd.getType()),
DeltaIterationBase<D, W> iterationOperator = new DeltaIterationBase<>(new BinaryOperatorInformation<>(iterationEnd.getType(), iterationEnd.getWorksetType(), iterationEnd.getType()),
iterationEnd.getKeyPositions(), name);

iterationOperator.setMaximumNumberOfIterations(iterationEnd.getMaxIterations());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ public void computeInterestingPropertiesForInputs(CostEstimator estimator) {

@Override
public List<PlanNode> getAlternativePlans(CostEstimator estimator) {

// check that union has only a single successor
if (this.getOutgoingConnections().size() > 1) {
throw new CompilerException("BinaryUnionNode has more than one successor.");
}

// check if we have a cached version
if (this.cachedPlans != null) {
return this.cachedPlans;
Expand Down Expand Up @@ -173,7 +179,7 @@ public List<PlanNode> getAlternativePlans(CostEstimator estimator) {
}
}

// create a candidate channel for the first input. mark it cached, if the connection says so
// create a candidate channel for the second input. mark it cached, if the connection says so
Channel c2 = new Channel(child2, this.input2.getMaterializationMode());
if (this.input2.getShipStrategy() == null) {
// free to choose the ship strategy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,25 @@

package org.apache.flink.optimizer;

import junit.framework.Assert;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.optimizer.plan.Channel;
import org.apache.flink.optimizer.plan.NAryUnionPlanNode;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.SingleInputPlanNode;
import org.apache.flink.optimizer.plantranslate.JobGraphGenerator;
import org.apache.flink.optimizer.util.CompilerTestBase;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.junit.Test;

import static org.junit.Assert.fail;
import java.util.List;

import static org.junit.Assert.*;

@SuppressWarnings("serial")
public class UnionReplacementTest extends CompilerTestBase {
Expand All @@ -54,4 +63,95 @@ public void testUnionReplacement() {
fail(e.getMessage());
}
}

/**
 * Test for FLINK-2662.
 *
 * Checks that a plan containing a union with two outputs is translated correctly.
 * The program can be illustrated as follows:
 *
 * Src1 ----------------\
 *                       >-> Union123 -> GroupBy(0) -> Sum -> Output
 * Src2 -\              /
 *        >-> Union23--<
 * Src3 -/              \
 *                       >-> Union234 -> GroupBy(1) -> Sum -> Output
 * Src4 ----------------/
 *
 * The fix for FLINK-2662 translates the union with two outputs (Union-23) into two separate
 * unions (Union-23_1 and Union-23_2) with one output each. Due to this change, the interesting
 * partitioning properties of GroupBy(0) and GroupBy(1) are pushed through Union-23_1 and
 * Union-23_2 respectively and do not interfere with each other (which would be the case if
 * Union-23 were a single operator with two outputs).
 *
 * @throws Exception if building or compiling the test program fails
 */
@Test
public void testUnionWithTwoOutputsTest() throws Exception {

    // -----------------------------------------------------------------------------------------
    // Build test program
    // -----------------------------------------------------------------------------------------

    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(DEFAULT_PARALLELISM);

    DataSet<Tuple2<Long, Long>> src1 = env.fromElements(new Tuple2<>(0L, 0L));
    DataSet<Tuple2<Long, Long>> src2 = env.fromElements(new Tuple2<>(0L, 0L));
    DataSet<Tuple2<Long, Long>> src3 = env.fromElements(new Tuple2<>(0L, 0L));
    DataSet<Tuple2<Long, Long>> src4 = env.fromElements(new Tuple2<>(0L, 0L));

    // union23 is consumed twice, i.e., it is the union with two outputs
    DataSet<Tuple2<Long, Long>> union23 = src2.union(src3);
    DataSet<Tuple2<Long, Long>> union123 = src1.union(union23);
    DataSet<Tuple2<Long, Long>> union234 = src4.union(union23);

    // the two aggregations group on different fields and are named "1" and "2"
    // so that they can be looked up in the optimized plan
    union123.groupBy(0).sum(1).name("1").output(new DiscardingOutputFormat<Tuple2<Long, Long>>());
    union234.groupBy(1).sum(0).name("2").output(new DiscardingOutputFormat<Tuple2<Long, Long>>());

    // -----------------------------------------------------------------------------------------
    // Verify optimized plan
    // -----------------------------------------------------------------------------------------

    OptimizedPlan optimizedPlan = compileNoStats(env.createProgramPlan());

    OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(optimizedPlan);

    SingleInputPlanNode groupRed1 = resolver.getNode("1");
    SingleInputPlanNode groupRed2 = resolver.getNode("2");

    // check partitioning is correct
    Assert.assertTrue("Reduce input should be partitioned on 0.",
        groupRed1.getInput().getGlobalProperties().getPartitioningFields().isExactMatch(new FieldList(0)));
    Assert.assertTrue("Reduce input should be partitioned on 1.",
        groupRed2.getInput().getGlobalProperties().getPartitioningFields().isExactMatch(new FieldList(1)));

    // check group reduce inputs are n-ary unions with three inputs
    Assert.assertTrue("Reduce input should be n-ary union with three inputs.",
        groupRed1.getInput().getSource() instanceof NAryUnionPlanNode &&
            ((NAryUnionPlanNode) groupRed1.getInput().getSource()).getListOfInputs().size() == 3);
    Assert.assertTrue("Reduce input should be n-ary union with three inputs.",
        groupRed2.getInput().getSource() instanceof NAryUnionPlanNode &&
            ((NAryUnionPlanNode) groupRed2.getInput().getSource()).getListOfInputs().size() == 3);

    // the casts are safe here because the instanceof checks above have passed
    NAryUnionPlanNode union123Node = (NAryUnionPlanNode) groupRed1.getInput().getSource();
    NAryUnionPlanNode union234Node = (NAryUnionPlanNode) groupRed2.getInput().getSource();

    // check channel from union to group reduce is forwarding
    Assert.assertTrue("Channel between union and group reduce should be forwarding",
        groupRed1.getInput().getShipStrategy().equals(ShipStrategyType.FORWARD));
    Assert.assertTrue("Channel between union and group reduce should be forwarding",
        groupRed2.getInput().getShipStrategy().equals(ShipStrategyType.FORWARD));

    // check that all inputs of unions are hash partitioned
    List<Channel> union123In = union123Node.getListOfInputs();
    for (Channel channel : union123In) {
        Assert.assertTrue("Union input channel should hash partition on 0",
            channel.getShipStrategy().equals(ShipStrategyType.PARTITION_HASH) &&
                channel.getShipStrategyKeys().isExactMatch(new FieldList(0)));
    }
    List<Channel> union234In = union234Node.getListOfInputs();
    for (Channel channel : union234In) {
        // the second branch groups on field 1, so its union inputs must be partitioned on 1
        Assert.assertTrue("Union input channel should hash partition on 1",
            channel.getShipStrategy().equals(ShipStrategyType.PARTITION_HASH) &&
                channel.getShipStrategyKeys().isExactMatch(new FieldList(1)));
    }

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
Expand Down Expand Up @@ -73,9 +74,9 @@ public class UnionClosedBranchingTest extends CompilerTestBase {
@Parameterized.Parameters
public static Collection<Object[]> params() {
Collection<Object[]> params = Arrays.asList(new Object[][]{
{ExecutionMode.PIPELINED, PIPELINED, BATCH},
{ExecutionMode.PIPELINED, BATCH, PIPELINED},
{ExecutionMode.PIPELINED_FORCED, PIPELINED, PIPELINED},
{ExecutionMode.BATCH, BATCH, BATCH},
{ExecutionMode.BATCH, BATCH, PIPELINED},
{ExecutionMode.BATCH_FORCED, BATCH, BATCH},
});

Expand All @@ -93,10 +94,16 @@ public static Collection<Object[]> params() {
/** Expected {@link DataExchangeMode} from union to join. */
private final DataExchangeMode unionToJoin;

/** Expected {@link ShipStrategyType} from source to union. */
private final ShipStrategyType sourceToUnionStrategy = ShipStrategyType.PARTITION_HASH;

/** Expected {@link ShipStrategyType} from union to join. */
private final ShipStrategyType unionToJoinStrategy = ShipStrategyType.FORWARD;

public UnionClosedBranchingTest(
ExecutionMode executionMode,
DataExchangeMode sourceToUnion,
DataExchangeMode unionToJoin) {
ExecutionMode executionMode,
DataExchangeMode sourceToUnion,
DataExchangeMode unionToJoin) {

this.executionMode = executionMode;
this.sourceToUnion = sourceToUnion;
Expand Down Expand Up @@ -140,12 +147,16 @@ public void testUnionClosedBranchingTest() throws Exception {
for (Channel channel : joinNode.getInputs()) {
assertEquals("Unexpected data exchange mode between union and join node.",
unionToJoin, channel.getDataExchangeMode());
assertEquals("Unexpected ship strategy between union and join node.",
unionToJoinStrategy, channel.getShipStrategy());
}

for (SourcePlanNode src : optimizedPlan.getDataSources()) {
for (Channel channel : src.getOutgoingChannels()) {
assertEquals("Unexpected data exchange mode between source and union node.",
sourceToUnion, channel.getDataExchangeMode());
assertEquals("Unexpected ship strategy between source and union node.",
sourceToUnionStrategy, channel.getShipStrategy());
}
}

Expand Down Expand Up @@ -176,9 +187,8 @@ public void testUnionClosedBranchingTest() throws Exception {
for (IntermediateDataSet dataSet : src.getProducedDataSets()) {
ResultPartitionType dsType = dataSet.getResultType();

// The result type is determined by the channel between the union and the join node
// and *not* the channel between source and union.
if (unionToJoin.equals(BATCH)) {
// Ensure batch exchange unless PIPELINED_FORCED is enabled.
if (!executionMode.equals(ExecutionMode.PIPELINED_FORCED)) {
assertTrue("Expected batch exchange, but result type is " + dsType + ".",
dsType.isBlocking());
} else {
Expand Down

0 comments on commit 303f6fe

Please sign in to comment.