Skip to content

Commit

Permalink
[SPARK-1540] Add an optional Ordering parameter to PairRDDFunctions.
Browse files Browse the repository at this point in the history
In https://issues.apache.org/jira/browse/SPARK-1540 we'd like to look at Spark's API to see if we can take advantage of Comparable keys in more places, which will make external spilling more efficient. This PR is a first step towards that that shows how to pass an Ordering when available and still continue functioning otherwise. It does this using a new implicit parameter with a default value of null.

The API is currently only in Scala -- in Java we'd have to add new versions of mapToPair and such that take a Comparator, or a new method to add a "type hint" to an RDD. We can address those later though.

Unfortunately requiring all keys to be Comparable would not work without requiring RDDs in general to contain only Comparable types. The reason is that methods such as distinct() and intersection() do a shuffle, but should be usable on RDDs of any type. So ordering will have to remain an optimization for the types that can be ordered. I think this isn't a horrible outcome though because one of the nice things about Spark's API is that it works on objects of *any* type, without requiring you to specify a schema or implement Writable or stuff like that.

Author: Matei Zaharia <[email protected]>

This patch had conflicts when merged, resolved by
Committer: Reynold Xin <[email protected]>

Closes #487 from mateiz/ordered-keys and squashes the following commits:

bd565f6 [Matei Zaharia] Pass an Ordering to only one version of groupBy because the Scala language spec doesn't allow having an optional parameter on all of them (this was only compiling in Scala 2.10 due to a bug).
4629965 [Matei Zaharia] Add tests for other versions of groupBy
3beae85 [Matei Zaharia] Added a test for implicit orderings
80b7a3b [Matei Zaharia] Add an optional Ordering parameter to PairRDDFunctions.
  • Loading branch information
mateiz authored and rxin committed Apr 24, 2014
1 parent 432201c commit 640f9a0
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 49 deletions.
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1281,8 +1281,10 @@ object SparkContext extends Logging {

// TODO: Add AccumulatorParams for other types, e.g. lists and strings

implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)])
(implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = {
new PairRDDFunctions(rdd)
}

implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)

Expand Down
29 changes: 16 additions & 13 deletions core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ import org.apache.spark.util.SerializableHyperLogLog
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
class PairRDDFunctions[K, V](self: RDD[(K, V)])
(implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null)
extends Logging
with SparkHadoopMapReduceUtil
with Serializable {

with Serializable
{
/**
* Generic function to combine the elements for each key using a custom set of aggregation
* functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C
Expand All @@ -77,7 +78,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
mapSideCombine: Boolean = true,
serializer: Serializer = null): RDD[(K, C)] = {
require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
if (getKeyClass().isArray) {
if (keyClass.isArray) {
if (mapSideCombine) {
throw new SparkException("Cannot use map-side combining with array keys.")
}
Expand Down Expand Up @@ -170,7 +171,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
*/
def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {

if (getKeyClass().isArray) {
if (keyClass.isArray) {
throw new SparkException("reduceByKeyLocally() does not support array keys")
}

Expand Down Expand Up @@ -288,7 +289,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* Return a copy of the RDD partitioned using the specified partitioner.
*/
def partitionBy(partitioner: Partitioner): RDD[(K, V)] = {
if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) {
if (keyClass.isArray && partitioner.isInstanceOf[HashPartitioner]) {
throw new SparkException("Default partitioner cannot partition array keys.")
}
if (self.partitioner == partitioner) self else new ShuffledRDD[K, V, (K, V)](self, partitioner)
Expand Down Expand Up @@ -458,7 +459,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
*/
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner)
: RDD[(K, (Iterable[V], Iterable[W]))] = {
if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)
Expand All @@ -473,7 +474,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
*/
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = {
if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner)
Expand Down Expand Up @@ -573,7 +574,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* supporting the key and value types K and V in this RDD.
*/
def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) {
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
saveAsHadoopFile(path, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}

/**
Expand All @@ -584,15 +585,15 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
def saveAsHadoopFile[F <: OutputFormat[K, V]](
path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) {
val runtimeClass = fm.runtimeClass
saveAsHadoopFile(path, getKeyClass, getValueClass, runtimeClass.asInstanceOf[Class[F]], codec)
saveAsHadoopFile(path, keyClass, valueClass, runtimeClass.asInstanceOf[Class[F]], codec)
}

/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) {
saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
saveAsNewAPIHadoopFile(path, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}

/**
Expand Down Expand Up @@ -782,7 +783,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
*/
def values: RDD[V] = self.map(_._2)

private[spark] def getKeyClass() = implicitly[ClassTag[K]].runtimeClass
private[spark] def keyClass: Class[_] = kt.runtimeClass

private[spark] def valueClass: Class[_] = vt.runtimeClass

private[spark] def getValueClass() = implicitly[ClassTag[V]].runtimeClass
private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord)
}
46 changes: 27 additions & 19 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(numPartitions: Int): RDD[T] =
def distinct(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] =
map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1)

/**
Expand All @@ -301,7 +301,7 @@ abstract class RDD[T: ClassTag](
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
*/
def repartition(numPartitions: Int): RDD[T] = {
def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = {
coalesce(numPartitions, shuffle = true)
}

Expand All @@ -325,7 +325,8 @@ abstract class RDD[T: ClassTag](
* coalesce(1000, shuffle = true) will result in 1000 partitions with the
* data distributed using a hash partitioner.
*/
def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T] = {
def coalesce(numPartitions: Int, shuffle: Boolean = false)(implicit ord: Ordering[T] = null)
: RDD[T] = {
if (shuffle) {
// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(
Expand Down Expand Up @@ -424,10 +425,11 @@ abstract class RDD[T: ClassTag](
*
* Note that this method performs a shuffle internally.
*/
def intersection(other: RDD[T]): RDD[T] =
def intersection(other: RDD[T]): RDD[T] = {
this.map(v => (v, null)).cogroup(other.map(v => (v, null)))
.filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty }
.keys
}

/**
* Return the intersection of this RDD and another one. The output will not contain any duplicate
Expand All @@ -437,10 +439,12 @@ abstract class RDD[T: ClassTag](
*
* @param partitioner Partitioner to use for the resulting RDD
*/
def intersection(other: RDD[T], partitioner: Partitioner): RDD[T] =
def intersection(other: RDD[T], partitioner: Partitioner)(implicit ord: Ordering[T] = null)
: RDD[T] = {
this.map(v => (v, null)).cogroup(other.map(v => (v, null)), partitioner)
.filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty }
.keys
}

/**
* Return the intersection of this RDD and another one. The output will not contain any duplicate
Expand All @@ -450,10 +454,11 @@ abstract class RDD[T: ClassTag](
*
* @param numPartitions How many partitions to use in the resulting RDD
*/
def intersection(other: RDD[T], numPartitions: Int): RDD[T] =
def intersection(other: RDD[T], numPartitions: Int): RDD[T] = {
this.map(v => (v, null)).cogroup(other.map(v => (v, null)), new HashPartitioner(numPartitions))
.filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty }
.keys
}

/**
* Return an RDD created by coalescing all elements within each partition into an array.
Expand All @@ -467,22 +472,25 @@ abstract class RDD[T: ClassTag](
def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)

/**
* Return an RDD of grouped items.
* Return an RDD of grouped items. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K: ClassTag](f: T => K): RDD[(K, Iterable[T])] =
def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] =
groupBy[K](f, defaultPartitioner(this))

/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K: ClassTag](f: T => K, numPartitions: Int): RDD[(K, Iterable[T])] =
def groupBy[K](f: T => K, numPartitions: Int)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] =
groupBy(f, new HashPartitioner(numPartitions))

/**
* Return an RDD of grouped items.
* Return an RDD of grouped items. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K: ClassTag](f: T => K, p: Partitioner): RDD[(K, Iterable[T])] = {
def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null)
: RDD[(K, Iterable[T])] = {
val cleanF = sc.clean(f)
this.map(t => (cleanF(t), t)).groupByKey(p)
}
Expand Down Expand Up @@ -739,7 +747,7 @@ abstract class RDD[T: ClassTag](
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
def subtract(other: RDD[T], p: Partitioner)(implicit ord: Ordering[T] = null): RDD[T] = {
if (partitioner == Some(p)) {
// Our partitioner knows how to handle T (which, since we have a partitioner, is
// really (K, V)) so make a new Partitioner that will de-tuple our fake tuples
Expand Down Expand Up @@ -847,7 +855,7 @@ abstract class RDD[T: ClassTag](
* Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): Map[T, Long] = {
def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = {
if (elementClassTag.runtimeClass.isArray) {
throw new SparkException("countByValue() does not support arrays")
}
Expand Down Expand Up @@ -877,10 +885,10 @@ abstract class RDD[T: ClassTag](
* Approximate version of countByValue().
*/
@Experimental
def countByValueApprox(
timeout: Long,
confidence: Double = 0.95
): PartialResult[Map[T, BoundedDouble]] = {
def countByValueApprox(timeout: Long, confidence: Double = 0.95)
(implicit ord: Ordering[T] = null)
: PartialResult[Map[T, BoundedDouble]] =
{
if (elementClassTag.runtimeClass.isArray) {
throw new SparkException("countByValueApprox() does not support arrays")
}
Expand Down Expand Up @@ -1030,13 +1038,13 @@ abstract class RDD[T: ClassTag](
* Returns the max of this RDD as defined by the implicit Ordering[T].
* @return the maximum element of the RDD
* */
def max()(implicit ord: Ordering[T]):T = this.reduce(ord.max)
def max()(implicit ord: Ordering[T]): T = this.reduce(ord.max)

/**
* Returns the min of this RDD as defined by the implicit Ordering[T].
* @return the minimum element of the RDD
* */
def min()(implicit ord: Ordering[T]):T = this.reduce(ord.min)
def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min)

/**
* Save this RDD as a text file, using string representations of elements.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag

val keyClass = getWritableClass[K]
val valueClass = getWritableClass[V]
val convertKey = !classOf[Writable].isAssignableFrom(self.getKeyClass)
val convertValue = !classOf[Writable].isAssignableFrom(self.getValueClass)
val convertKey = !classOf[Writable].isAssignableFrom(self.keyClass)
val convertValue = !classOf[Writable].isAssignableFrom(self.valueClass)

logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," +
valueClass.getSimpleName + ")" )
Expand Down
57 changes: 57 additions & 0 deletions core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import org.scalatest.FunSuite

import org.apache.spark.SparkContext._

class ImplicitOrderingSuite extends FunSuite with LocalSparkContext {
class NonOrderedClass {}

class ComparableClass extends Comparable[ComparableClass] {
override def compareTo(o: ComparableClass): Int = ???
}

class OrderedClass extends Ordered[OrderedClass] {
override def compare(o: OrderedClass): Int = ???
}

// Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should.
test("basic inference of Orderings"){
sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10)

// Infer orderings after basic maps to particular types
assert(rdd.map(x => (x, x)).keyOrdering.isDefined)
assert(rdd.map(x => (1, x)).keyOrdering.isDefined)
assert(rdd.map(x => (x.toString, x)).keyOrdering.isDefined)
assert(rdd.map(x => (null, x)).keyOrdering.isDefined)
assert(rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty)
assert(rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined)
assert(rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined)

// Infer orderings for other RDD methods
assert(rdd.groupBy(x => x).keyOrdering.isDefined)
assert(rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty)
assert(rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined)
assert(rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined)
assert(rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined)
assert(rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,10 @@ class StreamingContext private[streaming] (

object StreamingContext extends Logging {

implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = {
private[streaming] val DEFAULT_CLEANER_TTL = 3600

implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
(implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = {
new PairDStreamFunctions[K, V](stream)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ abstract class DStream[T: ClassTag] (
* the RDDs with `numPartitions` partitions (Spark's default number of partitions if
* `numPartitions` not specified).
*/
def countByValue(numPartitions: Int = ssc.sc.defaultParallelism): DStream[(T, Long)] =
def countByValue(numPartitions: Int = ssc.sc.defaultParallelism)(implicit ord: Ordering[T] = null)
: DStream[(T, Long)] =
this.map(x => (x, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions)

/**
Expand Down Expand Up @@ -686,9 +687,10 @@ abstract class DStream[T: ClassTag] (
def countByValueAndWindow(
windowDuration: Duration,
slideDuration: Duration,
numPartitions: Int = ssc.sc.defaultParallelism
): DStream[(T, Long)] = {

numPartitions: Int = ssc.sc.defaultParallelism)
(implicit ord: Ordering[T] = null)
: DStream[(T, Long)] =
{
this.map(x => (x, 1L)).reduceByKeyAndWindow(
(x: Long, y: Long) => x + y,
(x: Long, y: Long) => x - y,
Expand Down
Loading

0 comments on commit 640f9a0

Please sign in to comment.