diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java new file mode 100644 index 0000000000000..941b31b4406a2 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.test.iterative.aggregators; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.apache.flink.api.common.aggregators.ConvergenceCriterion; +import org.apache.flink.api.common.aggregators.LongSumAggregator; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichJoinFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.test.util.JavaProgramTestBase; +import org.apache.flink.test.util.MultipleProgramsTestBase; +import org.apache.flink.types.LongValue; +import org.apache.flink.util.Collector; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.operators.IterativeDataSet; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.*; + +/** + * Connected Components test case that uses a parameterizable convergence criterion + */ +@RunWith(Parameterized.class) +@SuppressWarnings("serial") +public class AggregatorConvergenceITCase extends MultipleProgramsTestBase { + + public AggregatorConvergenceITCase(TestExecutionMode mode) { + super(mode); + } + + @Test + public void testConnectedComponentsWithParametrizableConvergence() { + try { + List> verticesInput = Arrays.asList( + new Tuple2(1l,1l), + new Tuple2(2l,2l), + new Tuple2(3l,3l), + new Tuple2(4l,4l), + new Tuple2(5l,5l), + new Tuple2(6l,6l), + new Tuple2(7l,7l), + new Tuple2(8l,8l), + new Tuple2(9l,9l) + ); + + List> edgesInput = Arrays.asList( + new Tuple2(1l,2l), + new Tuple2(1l,3l), + new Tuple2(2l,3l), + new Tuple2(2l,4l), + new Tuple2(2l,1l), + new Tuple2(3l,1l), + new Tuple2(3l,2l), + new Tuple2(4l,2l), + new Tuple2(4l,6l), + new Tuple2(5l,6l), + new Tuple2(6l,4l), + new Tuple2(6l,5l), + new Tuple2(7l,8l), + new Tuple2(7l,9l), + new 
Tuple2(8l,7l), + new Tuple2(8l,9l), + new Tuple2(9l,7l), + new Tuple2(9l,8l) + ); + + // name of the aggregator that checks for convergence + final String UPDATED_ELEMENTS = "updated.elements.aggr"; + + // the iteration stops if less than this number of elements change value + final long convergence_threshold = 3; + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet> initialSolutionSet = env.fromCollection(verticesInput); + DataSet> edges = env.fromCollection(edgesInput); + + IterativeDataSet> iteration = + initialSolutionSet.iterate(10); + + // register the convergence criterion + iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS, + new LongSumAggregator(), new UpdatedElementsConvergenceCriterion(convergence_threshold)); + + DataSet> verticesWithNewComponents = iteration.join(edges).where(0).equalTo(0) + .with(new NeighborWithComponentIDJoin()) + .groupBy(0).min(1); + + DataSet> updatedComponentId = + verticesWithNewComponents.join(iteration).where(0).equalTo(0) + .flatMap(new MinimumIdFilter(UPDATED_ELEMENTS)); + + List> result = iteration.closeWith(updatedComponentId).collect(); + Collections.sort(result, new JavaProgramTestBase.TupleComparator>()); + + List> expectedResult = Arrays.asList( + new Tuple2(1L,1L), + new Tuple2(2L,1L), + new Tuple2(3L,1L), + new Tuple2(4L,1L), + new Tuple2(5L,2L), + new Tuple2(6L,1L), + new Tuple2(7L,7L), + new Tuple2(8L,7L), + new Tuple2(9L,7L) + ); + + assertEquals(expectedResult, result); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testParameterizableAggregator() { + try { + List> verticesInput = Arrays.asList( + new Tuple2(1l,1l), + new Tuple2(2l,2l), + new Tuple2(3l,3l), + new Tuple2(4l,4l), + new Tuple2(5l,5l), + new Tuple2(6l,6l), + new Tuple2(7l,7l), + new Tuple2(8l,8l), + new Tuple2(9l,9l) + ); + + List> edgesInput = Arrays.asList( + new Tuple2<>(1l,2l), + new Tuple2<>(1l,3l), + new Tuple2<>(2l,3l), + 
new Tuple2<>(2l,4l), + new Tuple2<>(2l,1l), + new Tuple2<>(3l,1l), + new Tuple2<>(3l,2l), + new Tuple2<>(4l,2l), + new Tuple2<>(4l,6l), + new Tuple2<>(5l,6l), + new Tuple2<>(6l,4l), + new Tuple2<>(6l,5l), + new Tuple2<>(7l,8l), + new Tuple2<>(7l,9l), + new Tuple2<>(8l,7l), + new Tuple2<>(8l,9l), + new Tuple2<>(9l,7l), + new Tuple2<>(9l,8l) + ); + + final int MAX_ITERATIONS = 5; + final String AGGREGATOR_NAME = "elements.in.component.aggregator"; + final long componentId = 1l; + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet> initialSolutionSet = env.fromCollection(verticesInput); + DataSet> edges = env.fromCollection(edgesInput); + + IterativeDataSet> iteration = + initialSolutionSet.iterate(MAX_ITERATIONS); + + // register the aggregator + iteration.registerAggregator(AGGREGATOR_NAME, new LongSumAggregatorWithParameter(componentId)); + + DataSet> verticesWithNewComponents = iteration.join(edges).where(0).equalTo(0) + .with(new NeighborWithComponentIDJoin()) + .groupBy(0).min(1); + + DataSet> updatedComponentId = + verticesWithNewComponents.join(iteration).where(0).equalTo(0) + .flatMap(new MinimumIdFilterCounting(AGGREGATOR_NAME)); + + List> result = iteration.closeWith(updatedComponentId).collect(); + + Collections.sort(result, new JavaProgramTestBase.TupleComparator>()); + + List> expectedResult = Arrays.asList( + new Tuple2<>(1L,1L), + new Tuple2<>(2L,1L), + new Tuple2<>(3L,1L), + new Tuple2<>(4L,1L), + new Tuple2<>(5L,1L), + new Tuple2<>(6L,1L), + new Tuple2<>(7L,7L), + new Tuple2<>(8L,7L), + new Tuple2<>(9L,7L) + ); + + // check program result + assertEquals(expectedResult, result); + + // check aggregators + long[] aggr_values = MinimumIdFilterCounting.aggr_value; + + // note that position 0 has the end result from superstep 1, retrieved at the start of iteration 2 + // position one has superstep 2, retrieved at the start of iteration 3. 
+ // the result from iteration 5 is not available, because no iteration 6 happens + assertEquals(3, aggr_values[0]); + assertEquals(4, aggr_values[1]); + assertEquals(5, aggr_values[2]); + assertEquals(6, aggr_values[3]); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // ------------------------------------------------------------------------ + // Test Functions + // ------------------------------------------------------------------------ + + public static final class NeighborWithComponentIDJoin extends RichJoinFunction, Tuple2, Tuple2> { + + private static final long serialVersionUID = 1L; + + @Override + public Tuple2 join(Tuple2 vertexWithCompId, Tuple2 edge) { + vertexWithCompId.f0 = edge.f1; + return vertexWithCompId; + } + } + + public static class MinimumIdFilter extends RichFlatMapFunction, Tuple2>, Tuple2> { + + private final String aggName; + private LongSumAggregator aggr; + + public MinimumIdFilter(String aggName) { + this.aggName = aggName; + } + + @Override + public void open(Configuration conf) { + aggr = getIterationRuntimeContext().getIterationAggregator(aggName); + } + + @Override + public void flatMap( + Tuple2, Tuple2> vertexWithNewAndOldId, + Collector> out) { + + if (vertexWithNewAndOldId.f0.f1 < vertexWithNewAndOldId.f1.f1) { + out.collect(vertexWithNewAndOldId.f0); + aggr.aggregate(1l); + } + else { + out.collect(vertexWithNewAndOldId.f1); + } + } + } + + public static final class MinimumIdFilterCounting + extends RichFlatMapFunction, Tuple2>, Tuple2> { + + private static final long[] aggr_value = new long[5]; + + private final String aggName; + private LongSumAggregatorWithParameter aggr; + + public MinimumIdFilterCounting(String aggName) { + this.aggName = aggName; + } + + @Override + public void open(Configuration conf) { + final int superstep = getIterationRuntimeContext().getSuperstepNumber(); + + aggr = getIterationRuntimeContext().getIterationAggregator(aggName); + + if (superstep > 1 && 
getIterationRuntimeContext().getIndexOfThisSubtask() == 0) { + LongValue val = getIterationRuntimeContext().getPreviousIterationAggregate(aggName); + aggr_value[superstep - 2] = val.getValue(); + } + } + + @Override + public void flatMap( + Tuple2, Tuple2> vertexWithNewAndOldId, + Collector> out) { + + if (vertexWithNewAndOldId.f0.f1 < vertexWithNewAndOldId.f1.f1) { + out.collect(vertexWithNewAndOldId.f0); + if (vertexWithNewAndOldId.f0.f1 == aggr.getComponentId()) { + aggr.aggregate(1l); + } + } else { + out.collect(vertexWithNewAndOldId.f1); + if (vertexWithNewAndOldId.f1.f1 == aggr.getComponentId()) { + aggr.aggregate(1l); + } + } + } + } + + /** A Convergence Criterion with one parameter */ + public static class UpdatedElementsConvergenceCriterion implements ConvergenceCriterion { + + private final long threshold; + + public UpdatedElementsConvergenceCriterion(long u_threshold) { + this.threshold = u_threshold; + } + + @Override + public boolean isConverged(int iteration, LongValue value) { + return value.getValue() < this.threshold; + } + } + + public static final class LongSumAggregatorWithParameter extends LongSumAggregator { + + private long componentId; + + public LongSumAggregatorWithParameter(long compId) { + this.componentId = compId; + } + + public long getComponentId() { + return this.componentId; + } + } +} diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/ConnectedComponentsWithParametrizableAggregatorITCase.java b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/ConnectedComponentsWithParametrizableAggregatorITCase.java deleted file mode 100644 index 8bf50de139115..0000000000000 --- a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/ConnectedComponentsWithParametrizableAggregatorITCase.java +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.test.iterative.aggregators; - -import java.util.ArrayList; -import java.util.List; - -import org.apache.flink.api.common.aggregators.LongSumAggregator; -import org.apache.flink.api.common.functions.JoinFunction; -import org.apache.flink.api.common.functions.RichFlatMapFunction; -import org.apache.flink.api.common.functions.RichGroupReduceFunction; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.test.util.JavaProgramTestBase; -import org.apache.flink.types.LongValue; -import org.apache.flink.util.Collector; -import org.junit.Assert; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.ExecutionEnvironment; -import org.apache.flink.api.java.operators.IterativeDataSet; - -/** - * Connected Components test case that uses a parameterizable aggregator - */ -public class ConnectedComponentsWithParametrizableAggregatorITCase extends JavaProgramTestBase { - - private static final int MAX_ITERATIONS = 5; - private static final int parallelism = 1; - - protected static List> verticesInput = new ArrayList>(); - protected static List> edgesInput = new ArrayList>(); - private String resultPath; - private String expectedResult; - - @Override - protected void 
preSubmit() throws Exception { - // vertices input - verticesInput.clear(); - verticesInput.add(new Tuple2(1l,1l)); - verticesInput.add(new Tuple2(2l,2l)); - verticesInput.add(new Tuple2(3l,3l)); - verticesInput.add(new Tuple2(4l,4l)); - verticesInput.add(new Tuple2(5l,5l)); - verticesInput.add(new Tuple2(6l,6l)); - verticesInput.add(new Tuple2(7l,7l)); - verticesInput.add(new Tuple2(8l,8l)); - verticesInput.add(new Tuple2(9l,9l)); - - // vertices input - edgesInput.clear(); - edgesInput.add(new Tuple2(1l,2l)); - edgesInput.add(new Tuple2(1l,3l)); - edgesInput.add(new Tuple2(2l,3l)); - edgesInput.add(new Tuple2(2l,4l)); - edgesInput.add(new Tuple2(2l,1l)); - edgesInput.add(new Tuple2(3l,1l)); - edgesInput.add(new Tuple2(3l,2l)); - edgesInput.add(new Tuple2(4l,2l)); - edgesInput.add(new Tuple2(4l,6l)); - edgesInput.add(new Tuple2(5l,6l)); - edgesInput.add(new Tuple2(6l,4l)); - edgesInput.add(new Tuple2(6l,5l)); - edgesInput.add(new Tuple2(7l,8l)); - edgesInput.add(new Tuple2(7l,9l)); - edgesInput.add(new Tuple2(8l,7l)); - edgesInput.add(new Tuple2(8l,9l)); - edgesInput.add(new Tuple2(9l,7l)); - edgesInput.add(new Tuple2(9l,8l)); - - resultPath = getTempDirPath("result"); - - expectedResult = "(1,1)\n" + "(2,1)\n" + "(3,1)\n" + "(4,1)\n" + - "(5,1)\n" + "(6,1)\n" + "(7,7)\n" + "(8,7)\n" + "(9,7)\n"; - } - - @Override - protected void testProgram() throws Exception { - ConnectedComponentsWithAggregatorProgram.runProgram(resultPath); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(expectedResult, resultPath); - long[] aggr_values = ConnectedComponentsWithAggregatorProgram.aggr_value; - - // note that position 0 has the end result from superstep 1, retrieved at the start of iteration 2 - // position one as superstep 2, retrieved at the start of iteration 3. 
- // the result from iteration 5 is not available, because no iteration 6 happens - Assert.assertEquals(3, aggr_values[0]); - Assert.assertEquals(4, aggr_values[1]); - Assert.assertEquals(5, aggr_values[2]); - Assert.assertEquals(6, aggr_values[3]); - } - - - private static class ConnectedComponentsWithAggregatorProgram { - - private static final String ELEMENTS_IN_COMPONENT = "elements.in.component.aggregator"; - private static final long componentId = 1l; - private static long [] aggr_value = new long [MAX_ITERATIONS]; - - public static String runProgram(String resultPath) throws Exception { - - final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); - env.setParallelism(parallelism); - - DataSet> initialSolutionSet = env.fromCollection(verticesInput); - DataSet> edges = env.fromCollection(edgesInput); - - IterativeDataSet> iteration = - initialSolutionSet.iterate(MAX_ITERATIONS); - - // register the aggregator - iteration.registerAggregator(ELEMENTS_IN_COMPONENT, new LongSumAggregatorWithParameter(componentId)); - - DataSet> verticesWithNewComponents = iteration.join(edges).where(0).equalTo(0) - .with(new NeighborWithComponentIDJoin()) - .groupBy(0).reduceGroup(new MinimumReduce()); - - DataSet> updatedComponentId = - verticesWithNewComponents.join(iteration).where(0).equalTo(0) - .flatMap(new MinimumIdFilter()); - - iteration.closeWith(updatedComponentId).writeAsText(resultPath); - - env.execute(); - - return resultPath; - } - } - - public static final class NeighborWithComponentIDJoin implements JoinFunction - , Tuple2, Tuple2> { - - private static final long serialVersionUID = 1L; - - @Override - public Tuple2 join(Tuple2 vertexWithCompId, - Tuple2 edge) throws Exception { - - vertexWithCompId.setField(edge.f1, 0); - return vertexWithCompId; - } - } - - public static final class MinimumReduce extends RichGroupReduceFunction, Tuple2> { - - private static final long serialVersionUID = 1L; - - private final Tuple2 resultVertex = new 
Tuple2(); - - @Override - public void reduce(Iterable> values, Collector> out) { - Long vertexId = 0L; - Long minimumCompId = Long.MAX_VALUE; - - for (Tuple2 value: values) { - vertexId = value.f0; - Long candidateCompId = value.f1; - if (candidateCompId < minimumCompId) { - minimumCompId = candidateCompId; - } - } - resultVertex.f0 = vertexId; - resultVertex.f1 = minimumCompId; - - out.collect(resultVertex); - } - } - - @SuppressWarnings("serial") - public static final class MinimumIdFilter extends RichFlatMapFunction, Tuple2>, Tuple2> { - - private static LongSumAggregatorWithParameter aggr; - - @Override - public void open(Configuration conf) { - aggr = getIterationRuntimeContext().getIterationAggregator( - ConnectedComponentsWithAggregatorProgram.ELEMENTS_IN_COMPONENT); - - int superstep = getIterationRuntimeContext().getSuperstepNumber(); - - if (superstep > 1) { - LongValue val = getIterationRuntimeContext().getPreviousIterationAggregate( - ConnectedComponentsWithAggregatorProgram.ELEMENTS_IN_COMPONENT); - ConnectedComponentsWithAggregatorProgram.aggr_value[superstep-2] = val.getValue(); - } - } - - @Override - public void flatMap( - Tuple2, Tuple2> vertexWithNewAndOldId, - Collector> out) throws Exception { - - if (vertexWithNewAndOldId.f0.f1 < vertexWithNewAndOldId.f1.f1) { - out.collect(vertexWithNewAndOldId.f0); - if (vertexWithNewAndOldId.f0.f1 == aggr.getComponentId()) { - aggr.aggregate(1l); - } - } else { - out.collect(vertexWithNewAndOldId.f1); - if (vertexWithNewAndOldId.f1.f1 == aggr.getComponentId()) { - aggr.aggregate(1l); - } - } - } - } - - // A LongSumAggregator with one parameter - @SuppressWarnings("serial") - public static final class LongSumAggregatorWithParameter extends LongSumAggregator { - - private long componentId; - - public LongSumAggregatorWithParameter(long compId) { - this.componentId = compId; - } - - public long getComponentId() { - return this.componentId; - } - } -} diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/ConnectedComponentsWithParametrizableConvergenceITCase.java b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/ConnectedComponentsWithParametrizableConvergenceITCase.java deleted file mode 100644 index e616a2b952dad..0000000000000 --- a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/ConnectedComponentsWithParametrizableConvergenceITCase.java +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.test.iterative.aggregators; - -import java.util.ArrayList; -import java.util.List; - -import org.apache.flink.api.common.aggregators.ConvergenceCriterion; -import org.apache.flink.api.common.aggregators.LongSumAggregator; -import org.apache.flink.api.common.functions.RichFlatMapFunction; -import org.apache.flink.api.common.functions.RichGroupReduceFunction; -import org.apache.flink.api.common.functions.RichJoinFunction; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.test.util.JavaProgramTestBase; -import org.apache.flink.types.LongValue; -import org.apache.flink.util.Collector; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.ExecutionEnvironment; -import org.apache.flink.api.java.operators.IterativeDataSet; - - -/** - * - * Connected Components test case that uses a parametrizable convergence criterion - * - */ -public class ConnectedComponentsWithParametrizableConvergenceITCase extends JavaProgramTestBase { - - private static final int MAX_ITERATIONS = 10; - private static final int parallelism = 1; - - protected static List> verticesInput = new ArrayList>(); - protected static List> edgesInput = new ArrayList>(); - private String resultPath; - private String expectedResult; - - @Override - protected void preSubmit() throws Exception { - // vertices input - verticesInput.clear(); - verticesInput.add(new Tuple2(1l,1l)); - verticesInput.add(new Tuple2(2l,2l)); - verticesInput.add(new Tuple2(3l,3l)); - verticesInput.add(new Tuple2(4l,4l)); - verticesInput.add(new Tuple2(5l,5l)); - verticesInput.add(new Tuple2(6l,6l)); - verticesInput.add(new Tuple2(7l,7l)); - verticesInput.add(new Tuple2(8l,8l)); - verticesInput.add(new Tuple2(9l,9l)); - - // vertices input - edgesInput.clear(); - edgesInput.add(new Tuple2(1l,2l)); - edgesInput.add(new Tuple2(1l,3l)); - edgesInput.add(new Tuple2(2l,3l)); - edgesInput.add(new Tuple2(2l,4l)); - 
edgesInput.add(new Tuple2(2l,1l)); - edgesInput.add(new Tuple2(3l,1l)); - edgesInput.add(new Tuple2(3l,2l)); - edgesInput.add(new Tuple2(4l,2l)); - edgesInput.add(new Tuple2(4l,6l)); - edgesInput.add(new Tuple2(5l,6l)); - edgesInput.add(new Tuple2(6l,4l)); - edgesInput.add(new Tuple2(6l,5l)); - edgesInput.add(new Tuple2(7l,8l)); - edgesInput.add(new Tuple2(7l,9l)); - edgesInput.add(new Tuple2(8l,7l)); - edgesInput.add(new Tuple2(8l,9l)); - edgesInput.add(new Tuple2(9l,7l)); - edgesInput.add(new Tuple2(9l,8l)); - - resultPath = getTempDirPath("result"); - - expectedResult = "(1,1)\n" + "(2,1)\n" + "(3,1)\n" + "(4,1)\n" + - "(5,2)\n" + "(6,1)\n" + "(7,7)\n" + "(8,7)\n" + "(9,7)\n"; - } - - @Override - protected void testProgram() throws Exception { - ConnectedComponentsWithConvergenceProgram.runProgram(resultPath); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(expectedResult, resultPath); - } - - - private static class ConnectedComponentsWithConvergenceProgram { - - private static final String UPDATED_ELEMENTS = "updated.elements.aggr"; - private static final long convergence_threshold = 3; // the iteration stops if less than this number os elements change value - - public static String runProgram(String resultPath) throws Exception { - - final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); - env.setParallelism(parallelism); - - DataSet> initialSolutionSet = env.fromCollection(verticesInput); - DataSet> edges = env.fromCollection(edgesInput); - - IterativeDataSet> iteration = - initialSolutionSet.iterate(MAX_ITERATIONS); - - // register the convergence criterion - iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS, - new LongSumAggregator(), new UpdatedElementsConvergenceCriterion(convergence_threshold)); - - DataSet> verticesWithNewComponents = iteration.join(edges).where(0).equalTo(0) - .with(new NeighborWithComponentIDJoin()) - .groupBy(0).reduceGroup(new 
MinimumReduce()); - - DataSet> updatedComponentId = - verticesWithNewComponents.join(iteration).where(0).equalTo(0) - .flatMap(new MinimumIdFilter()); - - iteration.closeWith(updatedComponentId).writeAsText(resultPath); - - env.execute(); - - return resultPath; - } - } - - public static final class NeighborWithComponentIDJoin extends RichJoinFunction, Tuple2, Tuple2> { - - private static final long serialVersionUID = 1L; - - @Override - public Tuple2 join(Tuple2 vertexWithCompId, - Tuple2 edge) throws Exception { - - vertexWithCompId.setField(edge.f1, 0); - return vertexWithCompId; - } - } - - public static final class MinimumReduce extends RichGroupReduceFunction, Tuple2> { - - private static final long serialVersionUID = 1L; - final Tuple2 resultVertex = new Tuple2(); - - @Override - public void reduce(Iterable> values, Collector> out) { - Long vertexId = 0L; - Long minimumCompId = Long.MAX_VALUE; - - for (Tuple2 value: values) { - vertexId = value.f0; - Long candidateCompId = value.f1; - if (candidateCompId < minimumCompId) { - minimumCompId = candidateCompId; - } - } - resultVertex.f0 = vertexId; - resultVertex.f1 = minimumCompId; - - out.collect(resultVertex); - } - } - - @SuppressWarnings("serial") - public static final class MinimumIdFilter extends RichFlatMapFunction, Tuple2>, Tuple2> { - - private static LongSumAggregator aggr; - - @Override - public void open(Configuration conf) { - aggr = getIterationRuntimeContext().getIterationAggregator( - ConnectedComponentsWithConvergenceProgram.UPDATED_ELEMENTS); - } - - @Override - public void flatMap( - Tuple2, Tuple2> vertexWithNewAndOldId, - Collector> out) throws Exception { - - if (vertexWithNewAndOldId.f0.f1 < vertexWithNewAndOldId.f1.f1) { - out.collect(vertexWithNewAndOldId.f0); - aggr.aggregate(1l); - } else { - out.collect(vertexWithNewAndOldId.f1); - } - } - } - - // A Convergence Criterion with one parameter - @SuppressWarnings("serial") - public static final class UpdatedElementsConvergenceCriterion 
implements ConvergenceCriterion { - - private long threshold; - - public UpdatedElementsConvergenceCriterion(long u_threshold) { - this.threshold = u_threshold; - } - - public long getThreshold() { - return this.threshold; - } - - @Override - public boolean isConverged(int iteration, LongValue value) { - return value.getValue() < this.threshold; - } - } - -}