diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 060c2f23eded8..876456c964770 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -120,8 +120,7 @@ class DirectKafkaInputDStream[ context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) // Report the record number of this batch interval to InputInfoTracker. - val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum - val inputInfo = InputInfo(id, numRecords) + val inputInfo = InputInfo(id, rdd.count) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) @@ -153,10 +152,7 @@ class DirectKafkaInputDStream[ override def restore() { // this is assuming that the topics don't change during execution, which is true currently val topics = fromOffsets.keySet - val leaders = kc.findLeaders(topics).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) + val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics)) batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 65d51d87f8486..3e6b937af57b0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -360,6 +360,14 @@ private[spark] object KafkaCluster { type Err = ArrayBuffer[Throwable] + /** If the result is right, return it, otherwise throw SparkException */ + def checkErrors[T](result: Either[Err, T]): T = { + result.fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + } + private[spark] case class LeaderOffset(host: String, port: Int, offset: Long) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index a1b4a12e5d6a0..c5cd2154772ac 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.kafka +import scala.collection.mutable.ArrayBuffer import scala.reflect.{classTag, ClassTag} import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.partial.{PartialResult, BoundedDouble} import org.apache.spark.rdd.RDD import org.apache.spark.util.NextIterator @@ -60,6 +62,48 @@ class KafkaRDD[ }.toArray } + override def count(): Long = offsetRanges.map(_.count).sum + + override def countApprox( + timeout: Long, + confidence: Double = 0.95 + ): PartialResult[BoundedDouble] = { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } + + override def isEmpty(): Boolean = count == 0L + + override def take(num: Int): Array[R] = { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (num < 1 || nonEmptyPartitions.size < 1) { + return new Array[R](0) + } + + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + val buf = new ArrayBuffer[R] + val res = context.runJob( + this, + (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray, + parts.keys.toArray, + allowLocal = true) + res.foreach(buf ++= _) + buf.toArray + } + override def getPreferredLocations(thePart: Partition): Seq[String] = { val part = thePart.asInstanceOf[KafkaRDDPartition] // TODO is additional hostname resolution necessary here diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala index a842a6f17766f..a660d2a00c35d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -35,4 +35,7 @@ class KafkaRDDPartition( val untilOffset: Long, val host: String, val port: Int -) extends Partition +) extends Partition { + /** Number of messages this partition refers to */ + def count(): Long = untilOffset - fromOffset +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 0b8a391a2c569..0e33362d34acd 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -158,15 +158,31 @@ object KafkaUtils { /** get leaders for the given offset ranges, or throw an exception */ private def leadersForRanges( - kafkaParams: Map[String, String], + kc: KafkaCluster, offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = { - val kc = new KafkaCluster(kafkaParams) val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet - val leaders = kc.findLeaders(topics).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) - leaders + val leaders = kc.findLeaders(topics) + KafkaCluster.checkErrors(leaders) + } + + /** Make sure offsets are available in kafka, or throw an exception */ + private def checkOffsets( + kc: KafkaCluster, + offsetRanges: Array[OffsetRange]): Unit = { + val topics = offsetRanges.map(_.topicAndPartition).toSet + val result = for { + low <- kc.getEarliestLeaderOffsets(topics).right + high <- kc.getLatestLeaderOffsets(topics).right + } yield { + offsetRanges.filterNot { o => + low(o.topicAndPartition).offset <= o.fromOffset && + o.untilOffset <= high(o.topicAndPartition).offset + } + } + val badRanges = KafkaCluster.checkErrors(result) + if (!badRanges.isEmpty) { + throw new SparkException("Offsets not available on leader: " + badRanges.mkString(",")) + } } /** @@ -191,7 +207,9 @@ object KafkaUtils { offsetRanges: Array[OffsetRange] ): RDD[(K, V)] = sc.withScope { val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) - val leaders = leadersForRanges(kafkaParams, offsetRanges) + val kc = new KafkaCluster(kafkaParams) + val leaders = leadersForRanges(kc, offsetRanges) + checkOffsets(kc, offsetRanges) new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler) } @@ -225,8 +243,9 @@ object KafkaUtils { leaders: Map[TopicAndPartition, Broker], messageHandler: MessageAndMetadata[K, V] => R ): RDD[R] = sc.withScope { + val kc = new KafkaCluster(kafkaParams) val leaderMap = if (leaders.isEmpty) { - leadersForRanges(kafkaParams, offsetRanges) + leadersForRanges(kc, offsetRanges) } else { // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker leaders.map { @@ -234,6 +253,7 @@ object KafkaUtils { }.toMap } val cleanedHandler = sc.clean(messageHandler) + checkOffsets(kc, offsetRanges) new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, cleanedHandler) } @@ -399,7 +419,7 @@ object KafkaUtils { val kc = new KafkaCluster(kafkaParams) val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) - (for { + val result = for { topicPartitions <- kc.getPartitions(topics).right leaderOffsets <- (if (reset == Some("smallest")) { kc.getEarliestLeaderOffsets(topicPartitions) @@ -412,10 +432,8 @@ object KafkaUtils { } new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( ssc, kafkaParams, fromOffsets, messageHandler) - }).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) + } + KafkaCluster.checkErrors(result) } /** diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 9c3dfeb8f5928..2675042666304 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -55,6 +55,12 @@ final class OffsetRange private( val untilOffset: Long) extends Serializable { import OffsetRange.OffsetRangeTuple + /** Kafka TopicAndPartition object, for convenience */ + def topicAndPartition(): TopicAndPartition = TopicAndPartition(topic, partition) + + /** Number of messages this OffsetRange refers to */ + def count(): Long = untilOffset - fromOffset + override def equals(obj: Any): Boolean = obj match { case that: OffsetRange => this.topic == that.topic && diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index d5baf5fd89994..f52a738afd65b 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -55,8 +55,8 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { test("basic usage") { val topic = s"topicbasic-${Random.nextInt}" kafkaTestUtils.createTopic(topic) - val messages = Set("the", "quick", "brown", "fox") - kafkaTestUtils.sendMessages(topic, messages.toArray) + val messages = Array("the", "quick", "brown", "fox") + kafkaTestUtils.sendMessages(topic, messages) val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt}") @@ -67,7 +67,27 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { sc, kafkaParams, offsetRanges) val received = rdd.map(_._2).collect.toSet - assert(received === messages) + assert(received === messages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === messages.size) + assert(rdd.countApprox(0).getFinalValue.mean === messages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head._2 === messages.head) + assert(rdd.take(messages.size + 10).size === messages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0))) + + assert(emptyRdd.isEmpty) + + // invalid offset ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, badRanges) + } } test("iterator boundary conditions") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8a93ca2999510..015d0296dd369 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -44,6 +44,9 @@ object MimaExcludes { // JavaRDDLike is not meant to be extended by user programs ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.partitioner"), + // Modification of private static method + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), // Mima false positive (was a private[spark] class) ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.util.collection.PairIterator"),