[SPARK-8127] [STREAMING] [KAFKA] KafkaRDD optimize count() take() isEmpty()

Take advantage of offset range info for size-related KafkaRDD methods. Possible fix for [SPARK-7122], but probably a worthwhile optimization regardless.

Author: cody koeninger <[email protected]>

Closes apache#6632 from koeninger/kafka-rdd-count and squashes the following commits:

321340d [cody koeninger] [SPARK-8127][Streaming][Kafka] additional test of ordering of take()
5a05d0f [cody koeninger] [SPARK-8127][Streaming][Kafka] additional test of isEmpty
f68bd32 [cody koeninger] [Streaming][Kafka][SPARK-8127] code cleanup
9555b73 [cody koeninger] Merge branch 'master' into kafka-rdd-count
253031d [cody koeninger] [Streaming][Kafka][SPARK-8127] mima exclusion for change to private method
8974b9e [cody koeninger] [Streaming][Kafka][SPARK-8127] check offset ranges before constructing KafkaRDD
c3768c5 [cody koeninger] [Streaming][Kafka] Take advantage of offset range info for size-related KafkaRDD methods.  Possible fix for [SPARK-7122], but probably a worthwhile optimization regardless.
koeninger authored and tdas committed Jun 20, 2015
1 parent bec40e5 commit 1b6fe9b
Showing 8 changed files with 122 additions and 24 deletions.
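Before the per-file diffs, a caller's-eye sketch (not part of the commit) of what the optimization buys. It assumes a reachable broker at localhost:9092 and a topic "mytopic" whose partition 0 actually holds offsets 0 through 100; broker address, topic name, and offsets are all placeholders.

import kafka.serializer.StringDecoder
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.streaming.kafka.{KafkaUtils, OffsetRange}

object KafkaRDDCountSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("kafka-rdd-count").setMaster("local[2]"))
    // Placeholder broker; the offsets requested below must exist on the leader,
    // otherwise the new checkOffsets validation throws a SparkException.
    val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")
    val ranges = Array(OffsetRange("mytopic", 0, fromOffset = 0, untilOffset = 100))

    val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
      sc, kafkaParams, ranges)

    // After this commit these are answered from the offset range metadata
    // rather than by scanning the messages:
    println(rdd.count())        // 100, the sum of (untilOffset - fromOffset) over the ranges
    println(rdd.isEmpty())      // false, decided without running a job over the data
    println(rdd.take(3).length) // 3, reading only a prefix of partition 0

    sc.stop()
  }
}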
DirectKafkaInputDStream.scala
@@ -120,8 +120,7 @@ class DirectKafkaInputDStream[
context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler)

// Report the record number of this batch interval to InputInfoTracker.
-      val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum
-      val inputInfo = InputInfo(id, numRecords)
+      val inputInfo = InputInfo(id, rdd.count)
ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)

currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset)
@@ -153,10 +152,7 @@
override def restore() {
// this is assuming that the topics don't change during execution, which is true currently
val topics = fromOffsets.keySet
-      val leaders = kc.findLeaders(topics).fold(
-        errs => throw new SparkException(errs.mkString("\n")),
-        ok => ok
-      )
+      val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics))

batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) =>
logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}")
KafkaCluster.scala
@@ -360,6 +360,14 @@ private[spark]
object KafkaCluster {
type Err = ArrayBuffer[Throwable]

/** If the result is right, return it, otherwise throw SparkException */
def checkErrors[T](result: Either[Err, T]): T = {
result.fold(
errs => throw new SparkException(errs.mkString("\n")),
ok => ok
)
}

private[spark]
case class LeaderOffset(host: String, port: Int, offset: Long)

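As an aside (not in the commit itself), checkErrors just collapses the Either[Err, T] values that KafkaCluster methods return. A self-contained sketch of the same fold, with RuntimeException standing in for SparkException so it compiles without Spark on the classpath:

import scala.collection.mutable.ArrayBuffer

object CheckErrorsSketch {
  type Err = ArrayBuffer[Throwable]

  // Same shape as KafkaCluster.checkErrors: unwrap the Right, or rethrow the
  // accumulated errors as a single exception.
  def checkErrors[T](result: Either[Err, T]): T =
    result.fold(
      errs => throw new RuntimeException(errs.mkString("\n")),
      ok => ok)

  def main(args: Array[String]): Unit = {
    val ok: Either[Err, Int] = Right(42)
    println(checkErrors(ok)) // 42

    val bad: Either[Err, Int] = Left(ArrayBuffer[Throwable](new IllegalStateException("no leader")))
    try checkErrors(bad)
    catch { case e: RuntimeException => println("rethrown: " + e.getMessage) }
  }
}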
KafkaRDD.scala
@@ -17,9 +17,11 @@

package org.apache.spark.streaming.kafka

import scala.collection.mutable.ArrayBuffer
import scala.reflect.{classTag, ClassTag}

import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext}
import org.apache.spark.partial.{PartialResult, BoundedDouble}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.NextIterator

@@ -60,6 +62,48 @@ class KafkaRDD[
}.toArray
}

override def count(): Long = offsetRanges.map(_.count).sum

override def countApprox(
timeout: Long,
confidence: Double = 0.95
): PartialResult[BoundedDouble] = {
val c = count
new PartialResult(new BoundedDouble(c, 1.0, c, c), true)
}

override def isEmpty(): Boolean = count == 0L

override def take(num: Int): Array[R] = {
val nonEmptyPartitions = this.partitions
.map(_.asInstanceOf[KafkaRDDPartition])
.filter(_.count > 0)

if (num < 1 || nonEmptyPartitions.size < 1) {
return new Array[R](0)
}

// Determine in advance how many messages need to be taken from each partition
val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) =>
val remain = num - result.values.sum
if (remain > 0) {
val taken = Math.min(remain, part.count)
result + (part.index -> taken.toInt)
} else {
result
}
}

val buf = new ArrayBuffer[R]
val res = context.runJob(
this,
(tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray,
parts.keys.toArray,
allowLocal = true)
res.foreach(buf ++= _)
buf.toArray
}

override def getPreferredLocations(thePart: Partition): Seq[String] = {
val part = thePart.asInstanceOf[KafkaRDDPartition]
// TODO is additional hostname resolution necessary here
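Not part of the diff: the foldLeft in take() above computes a per-partition plan before any tasks run, so only a prefix of the data is read. A standalone sketch of that planning step, with invented partition sizes:

object TakePlanSketch {
  // Walk the non-empty partitions in order, taking from each until `num` records
  // are covered, mirroring the logic in KafkaRDD.take.
  def plan(num: Int, partCounts: Seq[(Int, Long)]): Map[Int, Int] =
    partCounts.foldLeft(Map[Int, Int]()) { case (result, (index, count)) =>
      val remain = num - result.values.sum
      if (remain > 0) result + (index -> math.min(remain.toLong, count).toInt)
      else result
    }

  def main(args: Array[String]): Unit = {
    // Partitions 0, 1 and 2 hold 2, 5 and 4 messages respectively (made-up numbers).
    println(plan(6, Seq(0 -> 2L, 1 -> 5L, 2 -> 4L))) // Map(0 -> 2, 1 -> 4)
    println(plan(1, Seq(0 -> 2L, 1 -> 5L)))          // Map(0 -> 1)
  }
}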
KafkaRDDPartition.scala
@@ -35,4 +35,7 @@ class KafkaRDDPartition(
val untilOffset: Long,
val host: String,
val port: Int
-) extends Partition
+) extends Partition {
+  /** Number of messages this partition refers to */
+  def count(): Long = untilOffset - fromOffset
+}
KafkaUtils.scala
@@ -158,15 +158,31 @@ object KafkaUtils {

/** get leaders for the given offset ranges, or throw an exception */
private def leadersForRanges(
-      kafkaParams: Map[String, String],
+      kc: KafkaCluster,
       offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = {
-    val kc = new KafkaCluster(kafkaParams)
     val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet
-    val leaders = kc.findLeaders(topics).fold(
-      errs => throw new SparkException(errs.mkString("\n")),
-      ok => ok
-    )
-    leaders
+    val leaders = kc.findLeaders(topics)
+    KafkaCluster.checkErrors(leaders)
}

/** Make sure offsets are available in kafka, or throw an exception */
private def checkOffsets(
kc: KafkaCluster,
offsetRanges: Array[OffsetRange]): Unit = {
val topics = offsetRanges.map(_.topicAndPartition).toSet
val result = for {
low <- kc.getEarliestLeaderOffsets(topics).right
high <- kc.getLatestLeaderOffsets(topics).right
} yield {
offsetRanges.filterNot { o =>
low(o.topicAndPartition).offset <= o.fromOffset &&
o.untilOffset <= high(o.topicAndPartition).offset
}
}
val badRanges = KafkaCluster.checkErrors(result)
if (!badRanges.isEmpty) {
throw new SparkException("Offsets not available on leader: " + badRanges.mkString(","))
}
}
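Again as an aside, the filterNot in checkOffsets above keeps only the ranges the leader cannot serve: a range is acceptable when it lies entirely between the leader's earliest and latest offsets. A minimal standalone sketch of that predicate, with invented topic names and offsets:

object CheckOffsetsSketch {
  case class WantedRange(topic: String, partition: Int, fromOffset: Long, untilOffset: Long)

  // Return the requested ranges that fall outside [earliest, latest) on the leader.
  def badRanges(
      requested: Seq[WantedRange],
      low: Map[(String, Int), Long],
      high: Map[(String, Int), Long]): Seq[WantedRange] =
    requested.filterNot { o =>
      val tp = (o.topic, o.partition)
      low(tp) <= o.fromOffset && o.untilOffset <= high(tp)
    }

  def main(args: Array[String]): Unit = {
    val low = Map(("t", 0) -> 10L)
    val high = Map(("t", 0) -> 100L)
    // The second range asks for offsets the leader no longer (or does not yet) have.
    println(badRanges(Seq(WantedRange("t", 0, 10, 50), WantedRange("t", 0, 50, 150)), low, high))
    // List(WantedRange(t,0,50,150))
  }
}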

/**
@@ -191,7 +207,9 @@
offsetRanges: Array[OffsetRange]
): RDD[(K, V)] = sc.withScope {
val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
-    val leaders = leadersForRanges(kafkaParams, offsetRanges)
+    val kc = new KafkaCluster(kafkaParams)
+    val leaders = leadersForRanges(kc, offsetRanges)
+    checkOffsets(kc, offsetRanges)
new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler)
}

@@ -225,15 +243,17 @@
leaders: Map[TopicAndPartition, Broker],
messageHandler: MessageAndMetadata[K, V] => R
): RDD[R] = sc.withScope {
+    val kc = new KafkaCluster(kafkaParams)
     val leaderMap = if (leaders.isEmpty) {
-      leadersForRanges(kafkaParams, offsetRanges)
+      leadersForRanges(kc, offsetRanges)
} else {
// This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker
leaders.map {
case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port))
}.toMap
}
val cleanedHandler = sc.clean(messageHandler)
checkOffsets(kc, offsetRanges)
new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, cleanedHandler)
}

@@ -399,7 +419,7 @@ object KafkaUtils {
val kc = new KafkaCluster(kafkaParams)
val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)

-    (for {
+    val result = for {
topicPartitions <- kc.getPartitions(topics).right
leaderOffsets <- (if (reset == Some("smallest")) {
kc.getEarliestLeaderOffsets(topicPartitions)
@@ -412,10 +432,8 @@
}
new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
ssc, kafkaParams, fromOffsets, messageHandler)
-    }).fold(
-      errs => throw new SparkException(errs.mkString("\n")),
-      ok => ok
-    )
+    }
+    KafkaCluster.checkErrors(result)
}

/**
OffsetRange.scala
@@ -55,6 +55,12 @@ final class OffsetRange private(
val untilOffset: Long) extends Serializable {
import OffsetRange.OffsetRangeTuple

/** Kafka TopicAndPartition object, for convenience */
def topicAndPartition(): TopicAndPartition = TopicAndPartition(topic, partition)

/** Number of messages this OffsetRange refers to */
def count(): Long = untilOffset - fromOffset

override def equals(obj: Any): Boolean = obj match {
case that: OffsetRange =>
this.topic == that.topic &&
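For reference (not in the commit), the two accessors added above in action; topic name and offsets are made up. untilOffset is exclusive, so the range [1000, 1100) spans exactly 100 messages:

import org.apache.spark.streaming.kafka.OffsetRange

object OffsetRangeCountSketch {
  def main(args: Array[String]): Unit = {
    val range = OffsetRange("mytopic", 0, 1000, 1100)
    println(range.count())             // 100
    println(range.topicAndPartition()) // the kafka.common.TopicAndPartition for partition 0 of "mytopic"
  }
}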
KafkaRDDSuite.scala
@@ -55,8 +55,8 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
test("basic usage") {
val topic = s"topicbasic-${Random.nextInt}"
kafkaTestUtils.createTopic(topic)
-    val messages = Set("the", "quick", "brown", "fox")
-    kafkaTestUtils.sendMessages(topic, messages.toArray)
+    val messages = Array("the", "quick", "brown", "fox")
+    kafkaTestUtils.sendMessages(topic, messages)

val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
"group.id" -> s"test-consumer-${Random.nextInt}")
@@ -67,7 +67,27 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
sc, kafkaParams, offsetRanges)

val received = rdd.map(_._2).collect.toSet
-    assert(received === messages)
+    assert(received === messages.toSet)

// size-related method optimizations return sane results
assert(rdd.count === messages.size)
assert(rdd.countApprox(0).getFinalValue.mean === messages.size)
assert(!rdd.isEmpty)
assert(rdd.take(1).size === 1)
assert(rdd.take(1).head._2 === messages.head)
assert(rdd.take(messages.size + 10).size === messages.size)

val emptyRdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0)))

assert(emptyRdd.isEmpty)

// invalid offset ranges throw exceptions
val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1))
intercept[SparkException] {
KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
sc, kafkaParams, badRanges)
}
}

test("iterator boundary conditions") {
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
@@ -44,6 +44,9 @@ object MimaExcludes {
// JavaRDDLike is not meant to be extended by user programs
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.api.java.JavaRDDLike.partitioner"),
// Modification of private static method
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"),
// Mima false positive (was a private[spark] class)
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.util.collection.PairIterator"),
