Made RDD interface backwards-compatible. Reverted many RDDs.
Jan Wroblewski committed Sep 16, 2015
1 parent eaa7701 commit 7401505
Showing 44 changed files with 211 additions and 244 deletions.
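This commit reverts most RDD subclasses from the column-aware compute(...): PartitionData[T] signature back to the standard compute(...): Iterator[T], while a few RDDs (BlockRDD, ConversionRDD, EmptyRDD, LocalCheckpointRDD) override a separate computePartition(...): PartitionData[T] instead. The base RDD class itself is not among the files shown below, so the following is only a minimal, self-contained sketch of the presumed backwards-compatible pattern; PartitionData, IteratedPartitionData, and the default computePartition wrapper here are simplified stand-ins, not the fork's actual definitions.

```scala
// Simplified stand-ins for the fork's partition-data types (assumed shapes, not the real code).
trait PartitionData[T] {
  def iterator: Iterator[T]
}

case class IteratedPartitionData[T](iter: Iterator[T]) extends PartitionData[T] {
  def iterator: Iterator[T] = iter
}

// Backwards-compatible contract: ordinary RDDs implement the classic Iterator-returning
// compute, while column-aware RDDs may override computePartition instead. The default
// computePartition simply wraps compute's iterator, so reverted RDDs need no knowledge
// of PartitionData at all.
abstract class SketchRDD[T] {
  def compute(split: Int): Iterator[T]

  def computePartition(split: Int): PartitionData[T] =
    IteratedPartitionData(compute(split))
}
```

Under that assumption, the RDDs reverted below compile against the stock Iterator-based signature while the column-oriented path remains reachable through computePartition.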
17 changes: 7 additions & 10 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -61,7 +61,7 @@ private[spark] class PythonRDD(
if (preservePartitoning) firstParent.partitioner else None
}

override def compute(split: Partition, context: TaskContext): PartitionData[Array[Byte]] = {
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
val localdir = env.blockManager.diskBlockManager.localDirs.map(
@@ -180,8 +180,7 @@ private[spark] class PythonRDD(

override def hasNext: Boolean = _nextObj != null
}
// TODO version for ColumnPartitionData
IteratedPartitionData(new InterruptibleIterator(context, stdoutIterator))
new InterruptibleIterator(context, stdoutIterator)
}

val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
@@ -308,13 +307,11 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) {
override def getPartitions: Array[Partition] = prev.partitions
override val partitioner: Option[Partitioner] = prev.partitioner
override def compute(split: Partition, context: TaskContext): PartitionData[(Long, Array[Byte])] =
// TODO version for ColumnPartitionData
IteratedPartitionData(
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (Utils.deserializeLongValue(a), b)
case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
})
override def compute(split: Partition, context: TaskContext): Iterator[(Long, Array[Byte])] =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (Utils.deserializeLongValue(a), b)
case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
}
val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
}

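For context on the reverted PairwiseRDD.compute above: it pairs a flat stream of byte arrays two at a time, decoding the first element of each pair as a Long key. The runnable sketch below shows the grouped(2) pairing in isolation; the big-endian decoding is illustrative and stands in for Spark's Utils.deserializeLongValue.

```scala
import java.nio.ByteBuffer

object GroupedPairsSketch {
  // Illustrative stand-in for Utils.deserializeLongValue: read a big-endian Long from 8 bytes.
  def toLong(bytes: Array[Byte]): Long = ByteBuffer.wrap(bytes).getLong

  def main(args: Array[String]): Unit = {
    // A flat stream of alternating serialized-key / value records.
    val flat: Iterator[Array[Byte]] = Iterator(
      ByteBuffer.allocate(8).putLong(1L).array(), "a".getBytes("UTF-8"),
      ByteBuffer.allocate(8).putLong(2L).array(), "b".getBytes("UTF-8"))

    // grouped(2) yields Seq(key, value) pairs; anything else means malformed input.
    val pairs = flat.grouped(2).map {
      case Seq(k, v) => (toLong(k), new String(v, "UTF-8"))
      case x         => throw new IllegalArgumentException("unexpected value: " + x)
    }

    pairs.foreach(println) // prints (1,a) then (2,b)
  }
}
```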
7 changes: 3 additions & 4 deletions core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -45,7 +45,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
private var bootTime: Double = _
override def getPartitions: Array[Partition] = parent.partitions

override def compute(partition: Partition, context: TaskContext): PartitionData[U] = {
override def compute(partition: Partition, context: TaskContext): Iterator[U] = {

// Timing start
bootTime = System.currentTimeMillis / 1000.0
@@ -78,8 +78,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](

try {

// TODO version for ColumnPartitionData
return IteratedPartitionData(new Iterator[U] {
return new Iterator[U] {
def next(): U = {
val obj = _nextObj
if (hasNext) {
@@ -97,7 +96,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
}
hasMore
}
})
}
} catch {
case e: Exception =>
throw new SparkException("R computation failed with\n " + errThread.getLines())
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -41,7 +41,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds
}).toArray
}

override def compute(split: Partition, context: TaskContext): PartitionData[T] = {
override def computePartition(split: Partition, context: TaskContext): PartitionData[T] = {
assertValid()
val blockManager = SparkEnv.get.blockManager
val blockId = split.asInstanceOf[BlockRDDPartition].blockId
9 changes: 3 additions & 6 deletions core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -70,13 +70,10 @@ class CartesianRDD[T: ClassTag, U: ClassTag](
(rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct
}

override def compute(split: Partition, context: TaskContext): PartitionData[(T, U)] = {
override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = {
val currSplit = split.asInstanceOf[CartesianPartition]
val iter =
for (x <- rdd1.iterator(currSplit.s1, context);
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
// TODO version for ColumnPartitionData
IteratedPartitionData(iter)
for (x <- rdd1.iterator(currSplit.s1, context);
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}

override def getDependencies: Seq[Dependency[_]] = List(
core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -41,7 +41,6 @@ private[spark] abstract class CheckpointRDD[T: ClassTag](@transient sc: SparkCon
// base [[org.apache.spark.rdd.RDD]] class if we do not override the following methods.
// scalastyle:off
protected override def getPartitions: Array[Partition] = ???
override def compute(p: Partition, tc: TaskContext): PartitionData[T] = ???
// scalastyle:on

}
7 changes: 2 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -124,8 +124,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:

override val partitioner: Some[Partitioner] = Some(part)

override def compute(s: Partition, context: TaskContext):
PartitionData[(K, Array[Iterable[_]])] = {
override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = {
val sparkConf = SparkEnv.get.conf
val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true)
val split = s.asInstanceOf[CoGroupPartition]
@@ -148,7 +147,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
rddIterators += ((it, depNum))
}

val iter = if (!externalSorting) {
if (!externalSorting) {
val map = new AppendOnlyMap[K, CoGroupCombiner]
val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => {
if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup)
@@ -176,8 +175,6 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
new InterruptibleIterator(context,
map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
// TODO version for ColumnPartitionData
IteratedPartitionData(iter)
}

private def createExternalMap(numRdds: Int)
10 changes: 4 additions & 6 deletions core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -91,12 +91,10 @@ private[spark] class CoalescedRDD[T: ClassTag](
}
}

override def compute(partition: Partition, context: TaskContext): PartitionData[T] = {
// TODO version for ColumnPartitionData
IteratedPartitionData(
partition.asInstanceOf[CoalescedRDDPartition].parents.iterator.flatMap { parentPartition =>
firstParent[T].iterator(parentPartition, context)
})
override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
partition.asInstanceOf[CoalescedRDDPartition].parents.iterator.flatMap { parentPartition =>
firstParent[T].iterator(parentPartition, context)
}
}

override def getDependencies: Seq[Dependency[_]] = {
@@ -34,7 +34,7 @@ private[spark] class ConversionRDD[T: ClassTag](
override def getPartitions: Array[Partition] =
firstParent[T].partitions

override def compute(split: Partition, context: TaskContext): PartitionData[T] = {
override def computePartition(split: Partition, context: TaskContext): PartitionData[T] = {
val data = firstParent[T].partitionData(split, context)
(data, targetFormat) match {
// Cases where the format is already good
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
@@ -28,7 +28,7 @@ private[spark] class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc,

override def getPartitions: Array[Partition] = Array.empty

override def compute(split: Partition, context: TaskContext): PartitionData[T] = {
override def computePartition(split: Partition, context: TaskContext): PartitionData[T] = {
throw new UnsupportedOperationException("empty RDD")
}
}
10 changes: 4 additions & 6 deletions core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -212,7 +212,7 @@ class HadoopRDD[K, V](
array
}

override def compute(theSplit: Partition, context: TaskContext): PartitionData[(K, V)] = {
override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new NextIterator[(K, V)] {

val split = theSplit.asInstanceOf[HadoopPartition]
@@ -281,8 +281,7 @@ class HadoopRDD[K, V](
}
}
}
// TODO version for ColumnPartitionData
IteratedPartitionData(new InterruptibleIterator[(K, V)](context, iter))
new InterruptibleIterator[(K, V)](context, iter)
}

/** Maps over a partition, providing the InputSplit that was used as the base of the partition. */
@@ -375,11 +374,10 @@ private[spark] object HadoopRDD extends Logging {

override def getPartitions: Array[Partition] = firstParent[T].partitions

override def compute(split: Partition, context: TaskContext): PartitionData[U] = {
override def compute(split: Partition, context: TaskContext): Iterator[U] = {
val partition = split.asInstanceOf[HadoopPartition]
val inputSplit = partition.inputSplit.value
// TODO version for ColumnPartitionData
IteratedPartitionData(f(inputSplit, firstParent[T].iterator(split, context)))
f(inputSplit, firstParent[T].iterator(split, context))
}
}

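Both Hadoop-backed compute methods above once again return their records wrapped in an InterruptibleIterator, so a killed task stops consuming input. The sketch below is a simplified stand-in for that wrapper, not Spark's org.apache.spark.InterruptibleIterator; the cancellation-flag parameter is an assumption for illustration.

```scala
// Simplified stand-in: consult a cancellation flag before handing out each element.
class InterruptibleIteratorSketch[T](isCancelled: () => Boolean, delegate: Iterator[T])
  extends Iterator[T] {

  override def hasNext: Boolean = {
    // Fail fast if the surrounding task has been cancelled.
    if (isCancelled()) throw new InterruptedException("task was killed")
    delegate.hasNext
  }

  override def next(): T = delegate.next()
}

// Hypothetical usage, with the flag supplied by whatever runs the task:
// val records = new InterruptibleIteratorSketch(() => taskCancelled.get(), rawRecords)
```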
93 changes: 46 additions & 47 deletions core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -25,7 +25,7 @@ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.util.NextIterator
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext, PartitionData, IteratedPartitionData}
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}

private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
override def index: Int = idx
@@ -71,60 +71,59 @@ class JdbcRDD[T: ClassTag](
}).toArray
}

override def compute(thePart: Partition, context: TaskContext): PartitionData[T] =
// TODO version for ColumnPartitionData
IteratedPartitionData(new NextIterator[T] {
context.addTaskCompletionListener{ context => closeIfNeeded() }
val part = thePart.asInstanceOf[JdbcPartition]
val conn = getConnection()
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

// setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results,
// rather than pulling entire resultset into memory.
// see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
stmt.setFetchSize(Integer.MIN_VALUE)
logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
}
override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T]
{
context.addTaskCompletionListener{ context => closeIfNeeded() }
val part = thePart.asInstanceOf[JdbcPartition]
val conn = getConnection()
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

// setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results,
// rather than pulling entire resultset into memory.
// see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
stmt.setFetchSize(Integer.MIN_VALUE)
logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
}

stmt.setLong(1, part.lower)
stmt.setLong(2, part.upper)
val rs = stmt.executeQuery()
stmt.setLong(1, part.lower)
stmt.setLong(2, part.upper)
val rs = stmt.executeQuery()

override def getNext(): T = {
if (rs.next()) {
mapRow(rs)
} else {
finished = true
null.asInstanceOf[T]
}
override def getNext(): T = {
if (rs.next()) {
mapRow(rs)
} else {
finished = true
null.asInstanceOf[T]
}
}

override def close() {
try {
if (null != rs) {
rs.close()
}
} catch {
case e: Exception => logWarning("Exception closing resultset", e)
override def close() {
try {
if (null != rs) {
rs.close()
}
try {
if (null != stmt) {
stmt.close()
}
} catch {
case e: Exception => logWarning("Exception closing statement", e)
} catch {
case e: Exception => logWarning("Exception closing resultset", e)
}
try {
if (null != stmt) {
stmt.close()
}
try {
if (null != conn) {
conn.close()
}
logInfo("closed connection")
} catch {
case e: Exception => logWarning("Exception closing connection", e)
} catch {
case e: Exception => logWarning("Exception closing statement", e)
}
try {
if (null != conn) {
conn.close()
}
logInfo("closed connection")
} catch {
case e: Exception => logWarning("Exception closing connection", e)
}
})
}
}
}

object JdbcRDD {
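The reverted JdbcRDD.compute above builds on Spark's NextIterator, which separates record fetching (getNext plus a finished flag) from one-shot cleanup (close, invoked via closeIfNeeded). The sketch below is an illustrative reimplementation of that pattern, not Spark's actual NextIterator.

```scala
// Minimal NextIterator-style base class: subclasses implement getNext()/close() and set
// `finished` when the source is exhausted; close() runs at most once.
abstract class NextIteratorSketch[U] extends Iterator[U] {
  protected var finished = false
  private var gotNext = false
  private var nextValue: U = _
  private var closed = false

  protected def getNext(): U
  protected def close(): Unit

  def closeIfNeeded(): Unit = if (!closed) { closed = true; close() }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) closeIfNeeded()
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}

// Example subclass: drain a plain Scala iterator, "closing" a resource when it runs dry.
class DrainingIterator[T](source: Iterator[T]) extends NextIteratorSketch[T] {
  override protected def getNext(): T =
    if (source.hasNext) source.next()
    else { finished = true; null.asInstanceOf[T] }

  override protected def close(): Unit = println("closed underlying source")
}

// new DrainingIterator(Iterator(1, 2, 3)).foreach(println)  // 1, 2, 3, then the close message
```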
core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala
@@ -56,7 +56,7 @@ private[spark] class LocalCheckpointRDD[T: ClassTag](
* is expected to be fully cached and so all partitions should already be computed and
* available in the block storage.
*/
override def compute(partition: Partition, context: TaskContext): PartitionData[T] = {
override def computePartition(partition: Partition, context: TaskContext): PartitionData[T] = {
throw new SparkException(
s"Checkpoint block ${RDDBlockId(rddId, partition.index)} not found! Either the executor " +
s"that originally checkpointed this partition is no longer alive, or the original RDD is " +
core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext, PartitionData, IteratedPartitionData}
import org.apache.spark.{Partition, TaskContext}

private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
@@ -31,7 +31,6 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](

override def getPartitions: Array[Partition] = firstParent[T].partitions

override def compute(split: Partition, context: TaskContext): PartitionData[U] =
// TODO version for ColumnPartitionData
IteratedPartitionData(f(context, split.index, firstParent[T].iterator(split, context)))
override def compute(split: Partition, context: TaskContext): Iterator[U] =
f(context, split.index, firstParent[T].iterator(split, context))
}
9 changes: 4 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -100,7 +100,7 @@ class NewHadoopRDD[K, V](
result
}

override def compute(theSplit: Partition, context: TaskContext): PartitionData[(K, V)] = {
override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
@@ -193,7 +193,7 @@ class NewHadoopRDD[K, V](
}
}
}
IteratedPartitionData(new InterruptibleIterator(context, iter))
new InterruptibleIterator(context, iter)
}

/** Maps over a partition, providing the InputSplit that was used as the base of the partition. */
@@ -249,11 +249,10 @@ private[spark] object NewHadoopRDD {

override def getPartitions: Array[Partition] = firstParent[T].partitions

override def compute(split: Partition, context: TaskContext): PartitionData[U] = {
override def compute(split: Partition, context: TaskContext): Iterator[U] = {
val partition = split.asInstanceOf[NewHadoopPartition]
val inputSplit = partition.serializableHadoopSplit.value
// TODO version for ColumnPartitionData
IteratedPartitionData(f(inputSplit, firstParent[T].iterator(split, context)))
f(inputSplit, firstParent[T].iterator(split, context))
}
}
}
core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -98,10 +98,8 @@ private[spark] class ParallelCollectionRDD[T: ClassTag](
slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}

override def compute(s: Partition, context: TaskContext): IteratedPartitionData[T] = {
// TODO version for ColumnPartitionData
IteratedPartitionData(
new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator))
override def compute(s: Partition, context: TaskContext): Iterator[T] = {
new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
}

override def getPreferredLocations(s: Partition): Seq[String] = {