[SPARK-5843] [API] Allowing map-side combine to be specified in Java.
Specifically, when calling JavaPairRDD.combineByKey(), there is a new
six-parameter method that exposes the map-side-combine boolean as the
fifth parameter and the serializer as the sixth parameter.
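
A minimal usage sketch (not taken from this commit) of the new six-parameter overload, assuming Java 8 lambdas and a local-mode JavaSparkContext; the app name, sample data, and variable names are illustrative:

import java.util.Arrays;

import scala.Tuple2;

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.serializer.KryoSerializer;

SparkConf conf = new SparkConf().setAppName("CombineByKeyExample").setMaster("local");
JavaSparkContext sc = new JavaSparkContext(conf);
JavaPairRDD<String, Integer> pairs = sc.parallelizePairs(Arrays.asList(
    new Tuple2<>("a", 1), new Tuple2<>("a", 2), new Tuple2<>("b", 3)));

// New six-parameter overload: the map-side-combine flag and shuffle serializer are explicit.
JavaPairRDD<String, Integer> sums = pairs.combineByKey(
    v -> v,                               // createCombiner: turn a V into a C
    (c, v) -> c + v,                      // mergeValue: fold a V into an existing C
    (c1, c2) -> c1 + c2,                  // mergeCombiners: merge two C's
    new HashPartitioner(2),               // partitioner for the output RDD
    false,                                // mapSideCombine (the new fifth parameter)
    new KryoSerializer(conf));            // shuffle serializer (the new sixth parameter)

System.out.println(sums.collectAsMap());  // e.g. {a=3, b=3}

The pre-existing overloads keep their behavior by delegating to this method with map-side combine enabled and the default serializer.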

Author: mcheah <[email protected]>

Closes apache#4634 from mccheah/pair-rdd-map-side-combine and squashes the following commits:

5c58319 [mcheah] Fixing compiler errors.
3ce7deb [mcheah] Addressing style and documentation comments.
7455c7a [mcheah] Allowing Java combineByKey to specify Serializer as well.
6ddd729 [mcheah] [SPARK-5843] Allowing map-side combine to be specified in Java.
mccheah authored and srowen committed Mar 19, 2015
1 parent 797f8a0 commit 3c4e486
Showing 2 changed files with 87 additions and 12 deletions.
46 changes: 37 additions & 9 deletions core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -39,6 +39,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
import org.apache.spark.rdd.RDD.rddToPairRDDFunctions
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils

@@ -227,24 +228,51 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
*
* In addition, users can control the partitioning of the output RDD, and whether to perform
* map-side aggregation (if a mapper can produce multiple items with the same key).
* In addition, users can control the partitioning of the output RDD, the serializer that is used
* for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple
* items with the same key).
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
partitioner: Partitioner): JavaPairRDD[K, C] = {
implicit val ctag: ClassTag[C] = fakeClassTag
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
partitioner: Partitioner,
mapSideCombine: Boolean,
serializer: Serializer): JavaPairRDD[K, C] = {
implicit val ctag: ClassTag[C] = fakeClassTag
fromRDD(rdd.combineByKey(
createCombiner,
mergeValue,
mergeCombiners,
partitioner
partitioner,
mapSideCombine,
serializer
))
}

/**
* Simplified version of combineByKey that hash-partitions the output RDD.
* Generic function to combine the elements for each key using a custom set of aggregation
* functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
* "combined type" C * Note that V and C can be different -- for example, one might group an
* RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
* functions:
*
* - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
*
* In addition, users can control the partitioning of the output RDD. This method automatically
* uses map-side aggregation in shuffling the RDD.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
partitioner: Partitioner): JavaPairRDD[K, C] = {
combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner, true, null)
}

/**
* Simplified version of combineByKey that hash-partitions the output RDD and uses map-side
* aggregation.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
@@ -488,7 +516,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])

/**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
* partitioner/parallelism level.
* partitioner/parallelism level and using map-side aggregation.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
53 changes: 50 additions & 3 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -24,11 +24,12 @@
import java.util.*;
import java.util.concurrent.*;

import org.apache.spark.input.PortableDataStream;
import scala.collection.JavaConversions;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
@@ -51,8 +52,11 @@
import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.*;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.partial.BoundedDouble;
import org.apache.spark.partial.PartialResult;
import org.apache.spark.rdd.RDD;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.util.StatCounter;

@@ -726,8 +730,8 @@ public void javaDoubleRDDHistoGram() {
Tuple2<double[], long[]> results = rdd.histogram(2);
double[] expected_buckets = {1.0, 2.5, 4.0};
long[] expected_counts = {2, 2};
Assert.assertArrayEquals(expected_buckets, results._1, 0.1);
Assert.assertArrayEquals(expected_counts, results._2);
Assert.assertArrayEquals(expected_buckets, results._1(), 0.1);
Assert.assertArrayEquals(expected_counts, results._2());
// Test with provided buckets
long[] histogram = rdd.histogram(expected_buckets);
Assert.assertArrayEquals(expected_counts, histogram);
@@ -1424,6 +1428,49 @@ public void checkpointAndRestore() {
Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
}

@Test
public void combineByKey() {
JavaRDD<Integer> originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6));
Function<Integer, Integer> keyFunction = new Function<Integer, Integer>() {
@Override
public Integer call(Integer v1) throws Exception {
return v1 % 3;
}
};
Function<Integer, Integer> createCombinerFunction = new Function<Integer, Integer>() {
@Override
public Integer call(Integer v1) throws Exception {
return v1;
}
};

Function2<Integer, Integer, Integer> mergeValueFunction = new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer v1, Integer v2) throws Exception {
return v1 + v2;
}
};

JavaPairRDD<Integer, Integer> combinedRDD = originalRDD.keyBy(keyFunction)
.combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction);
Map<Integer, Integer> results = combinedRDD.collectAsMap();
ImmutableMap<Integer, Integer> expected = ImmutableMap.of(0, 9, 1, 5, 2, 7);
Assert.assertEquals(expected, results);

Partitioner defaultPartitioner = Partitioner.defaultPartitioner(
combinedRDD.rdd(), JavaConversions.asScalaBuffer(Lists.<RDD<?>>newArrayList()));
combinedRDD = originalRDD.keyBy(keyFunction)
.combineByKey(
createCombinerFunction,
mergeValueFunction,
mergeValueFunction,
defaultPartitioner,
false,
new KryoSerializer(new SparkConf()));
results = combinedRDD.collectAsMap();
Assert.assertEquals(expected, results);
}

@SuppressWarnings("unchecked")
@Test
public void mapOnPairRDD() {
