[SPARK-11060] [STREAMING] Fix some potential NPE in DStream transformation

This patch fixes:

1. Guard against NPEs in `TransformedDStream` when a parent DStream returns None instead of an empty RDD.
2. Verify input streams that may potentially return None.
3. Add unit tests to verify the behavior when an input stream returns None.

cc tdas, please help to review, thanks a lot :).

Author: jerryshao <[email protected]>

Closes #9070 from jerryshao/SPARK-11060.
jerryshao authored and srowen committed Oct 16, 2015
1 parent eb0b4d6 commit 43f5d1f
Showing 5 changed files with 83 additions and 9 deletions.
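
Background for the reader: before this patch, `TransformedDStream.compute` converted each parent's `Option[RDD]` to a value with `orNull`, so a parent returning `None` injected a null that blew up later with no useful context. A minimal plain-Scala sketch of that failure mode (no Spark dependency; `Seq` stands in for `RDD`, and all names are illustrative only):

```scala
import scala.util.Try

// Plain-Scala sketch of the bug: Seq stands in for RDD, Option for the
// result of DStream.getOrCompute.
object NpeSketch {
  def main(args: Array[String]): Unit = {
    val maybeBatch: Option[Seq[Int]] = None  // parent produced no RDD this batch

    // Before the patch: orNull smuggles a null into the pipeline; the NPE
    // surfaces only when the value is used, far from the real cause.
    println(Try(maybeBatch.orNull.map(_ + 1)))
    // => Failure(java.lang.NullPointerException)

    // After the patch: fail fast at the source with a descriptive message.
    println(Try(maybeBatch.getOrElse(
      throw new IllegalStateException("Couldn't generate RDD from parent"))))
    // => Failure(java.lang.IllegalStateException: Couldn't generate RDD from parent)
  }
}
```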
ConstantInputDStream.scala
@@ -17,16 +17,20 @@
 
 package org.apache.spark.streaming.dstream
 
+import scala.reflect.ClassTag
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.{Time, StreamingContext}
-import scala.reflect.ClassTag
 
 /**
  * An input stream that always returns the same RDD on each timestep. Useful for testing.
  */
 class ConstantInputDStream[T: ClassTag](ssc_ : StreamingContext, rdd: RDD[T])
   extends InputDStream[T](ssc_) {
 
+  require(rdd != null,
+    "parameter rdd null is illegal, which will lead to NPE in the following transformation")
+
   override def start() {}
 
   override def stop() {}
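
The `require` added above turns a null `rdd` into an immediate `IllegalArgumentException` at construction time rather than an NPE deep inside a later transformation. A standalone sketch of the same fail-fast pattern (the `describe` helper is hypothetical, not part of the patch):

```scala
import scala.util.Try

// Mirrors the guard added to ConstantInputDStream: require rejects a null
// argument immediately, with a clear message, instead of letting an NPE
// surface later in a transformation.
object RequireSketch {
  def describe(rdd: Seq[Int]): String = {
    require(rdd != null,
      "parameter rdd null is illegal, which will lead to NPE in the following transformation")
    s"batch of ${rdd.size} elements"
  }

  def main(args: Array[String]): Unit = {
    println(describe(Seq(1, 2, 3)))  // batch of 3 elements
    println(Try(describe(null)))     // Failure(java.lang.IllegalArgumentException: requirement failed: ...)
  }
}
```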
QueueInputDStream.scala
@@ -62,7 +62,7 @@ class QueueInputDStream[T: ClassTag](
     } else if (defaultRDD != null) {
       Some(defaultRDD)
     } else {
-      None
+      Some(ssc.sparkContext.emptyRDD)
     }
   }
 
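
Design note on the hunk above: returning `Some(ssc.sparkContext.emptyRDD)` instead of `None` hands downstream operators a real zero-element batch, so nothing downstream has to special-case a missing RDD. A self-contained sketch of that distinction (assumes Spark on the classpath; the local master and app name are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// An empty RDD is a valid zero-element batch that downstream transformations
// handle uniformly, whereas None forces every consumer to handle the absence.
object EmptyRddSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("EmptyRddSketch"))
    try {
      val batch = Some(sc.emptyRDD[Int])     // what QueueInputDStream now returns
      println(batch.get.map(_ + 1).count())  // 0 -- no special-casing needed
    } finally {
      sc.stop()
    }
  }
}
```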
TransformedDStream.scala
@@ -20,7 +20,7 @@ package org.apache.spark.streaming.dstream
 import scala.reflect.ClassTag
 
 import org.apache.spark.SparkException
-import org.apache.spark.rdd.{PairRDDFunctions, RDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.{Duration, Time}
 
 private[streaming]
@@ -39,7 +39,10 @@ class TransformedDStream[U: ClassTag] (
   override def slideDuration: Duration = parents.head.slideDuration
 
   override def compute(validTime: Time): Option[RDD[U]] = {
-    val parentRDDs = parents.map(_.getOrCompute(validTime).orNull).toSeq
+    val parentRDDs = parents.map { parent => parent.getOrCompute(validTime).getOrElse(
+      // Guard out against parent DStream that return None instead of Some(rdd) to avoid NPE
+      throw new SparkException(s"Couldn't generate RDD from parent at time $validTime"))
+    }
     val transformedRDD = transformFunc(parentRDDs, validTime)
     if (transformedRDD == null) {
       throw new SparkException("Transform function must not return null. " +
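
A detail worth noting in the hunk above: `Option.getOrElse` takes its default by name, so the `throw` is evaluated only when the parent actually returned `None`; the happy path constructs no exception. A minimal demonstration:

```scala
// getOrElse(default: => B) evaluates its argument lazily, so the throw in
// TransformedDStream.compute fires only for a missing parent RDD.
object GetOrElseSketch {
  def main(args: Array[String]): Unit = {
    def loudDefault: Int = { println("default evaluated"); 42 }

    println(Some(1).getOrElse(loudDefault))             // prints 1; default never evaluated
    println((None: Option[Int]).getOrElse(loudDefault)) // prints "default evaluated", then 42
  }
}
```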
UnionDStream.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.streaming.dstream
 
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkException
 import org.apache.spark.streaming.{Duration, Time}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.UnionRDD
 
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.ClassTag
-
 private[streaming]
 class UnionDStream[T: ClassTag](parents: Array[DStream[T]])
   extends DStream[T](parents.head.ssc) {
@@ -41,8 +42,8 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]])
     val rdds = new ArrayBuffer[RDD[T]]()
     parents.map(_.getOrCompute(validTime)).foreach {
       case Some(rdd) => rdds += rdd
-      case None => throw new Exception("Could not generate RDD from a parent for unifying at time "
-        + validTime)
+      case None => throw new SparkException("Could not generate RDD from a parent for unifying at" +
+        s" time $validTime")
     }
     if (rdds.size > 0) {
       Some(new UnionRDD(ssc.sc, rdds))
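
The surrounding `compute` uses a collect-or-fail pattern: accumulate every defined parent batch and abort with a descriptive exception the moment one is missing. A plain-Scala sketch of the same pattern (`Seq` stands in for `RDD`, `RuntimeException` for `SparkException`; names are illustrative only):

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Try

// Collect-or-fail, as in UnionDStream.compute: gather every defined batch,
// fail fast with a descriptive error if any parent produced nothing.
object UnionSketch {
  def unify(parentBatches: Seq[Option[Seq[Int]]]): Seq[Int] = {
    val rdds = new ArrayBuffer[Seq[Int]]()
    parentBatches.foreach {
      case Some(rdd) => rdds += rdd
      case None =>
        throw new RuntimeException("Could not generate RDD from a parent for unifying")
    }
    rdds.flatten.toSeq
  }

  def main(args: Array[String]): Unit = {
    println(unify(Seq(Some(Seq(1, 2)), Some(Seq(3)))))  // List(1, 2, 3)
    println(Try(unify(Seq(Some(Seq(1)), None))))        // Failure(java.lang.RuntimeException: ...)
  }
}
```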
BasicOperationsSuite.scala
@@ -191,6 +191,20 @@ class BasicOperationsSuite extends TestSuiteBase {
     )
   }
 
+  test("union with input stream return None") {
+    val input = Seq(1 to 4, 101 to 104, 201 to 204, null)
+    val output = Seq(1 to 8, 101 to 108, 201 to 208)
+    intercept[SparkException] {
+      testOperation(
+        input,
+        (s: DStream[Int]) => s.union(s.map(_ + 4)),
+        output,
+        input.length,
+        false
+      )
+    }
+  }
+
   test("StreamingContext.union") {
     val input = Seq(1 to 4, 101 to 104, 201 to 204)
     val output = Seq(1 to 12, 101 to 112, 201 to 212)
@@ -224,6 +238,19 @@ class BasicOperationsSuite extends TestSuiteBase {
     }
   }
 
+  test("transform with input stream return None") {
+    val input = Seq(1 to 4, 5 to 8, null)
+    intercept[SparkException] {
+      testOperation(
+        input,
+        (r: DStream[Int]) => r.transform(rdd => rdd.map(_.toString)),
+        input.filterNot(_ == null).map(_.map(_.toString)),
+        input.length,
+        false
+      )
+    }
+  }
+
   test("transformWith") {
     val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() )
     val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") )
@@ -244,6 +271,27 @@
     testOperation(inputData1, inputData2, operation, outputData, true)
   }
 
+  test("transformWith with input stream return None") {
+    val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), null )
+    val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), null )
+    val outputData = Seq(
+      Seq("a", "b", "a", "b"),
+      Seq("a", "b", "", ""),
+      Seq("")
+    )
+
+    val operation = (s1: DStream[String], s2: DStream[String]) => {
+      s1.transformWith(  // RDD.join in transform
+        s2,
+        (rdd1: RDD[String], rdd2: RDD[String]) => rdd1.union(rdd2)
+      )
+    }
+
+    intercept[SparkException] {
+      testOperation(inputData1, inputData2, operation, outputData, inputData1.length, true)
+    }
+  }
+
   test("StreamingContext.transform") {
     val input = Seq(1 to 4, 101 to 104, 201 to 204)
     val output = Seq(1 to 12, 101 to 112, 201 to 212)
@@ -260,6 +308,24 @@
     testOperation(input, operation, output)
   }
 
+  test("StreamingContext.transform with input stream return None") {
+    val input = Seq(1 to 4, 101 to 104, 201 to 204, null)
+    val output = Seq(1 to 12, 101 to 112, 201 to 212)
+
+    // transform over 3 DStreams by doing union of the 3 RDDs
+    val operation = (s: DStream[Int]) => {
+      s.context.transform(
+        Seq(s, s.map(_ + 4), s.map(_ + 8)),  // 3 DStreams
+        (rdds: Seq[RDD[_]], time: Time) =>
+          rdds.head.context.union(rdds.map(_.asInstanceOf[RDD[Int]]))  // union of RDDs
+      )
+    }
+
+    intercept[SparkException] {
+      testOperation(input, operation, output, input.length, false)
+    }
+  }
+
   test("cogroup") {
     val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() )
     val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() )
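
All of the new tests lean on ScalaTest's `intercept`, which fails the test unless the block throws the expected exception type, and returns the thrown exception for further checks. A standalone sketch (assumes a ScalaTest 2.x/3.0-era `FunSuite` and spark-core on the classpath; the test class and message are illustrative):

```scala
import org.apache.spark.SparkException
import org.scalatest.FunSuite

// Illustration of intercept, as used in the new tests: it fails the test
// unless the block throws a SparkException, and returns the thrown exception
// so its message can be inspected.
class InterceptSketch extends FunSuite {
  test("intercept returns the thrown exception for inspection") {
    val e = intercept[SparkException] {
      throw new SparkException("Couldn't generate RDD from parent at time 1000 ms")
    }
    assert(e.getMessage.contains("Couldn't generate RDD"))
  }
}
```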
