[SPARK-10731] [SQL] Delegate to Scala's DataFrame.take implementation in Python DataFrame.

Python DataFrame.head/take currently requires scanning all the partitions (it is implemented as limit(num).collect()). This pull request changes them to delegate the actual implementation to the Scala DataFrame by calling DataFrame.take, which only scans as many partitions as needed to produce the requested rows.

This is more of a hack to fix the issue for 1.5.1. A more proper fix would be to change executeCollect and executeTake to return InternalRow rather than Row, and thus eliminate the extra round-trip conversion.
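
To see the round trip in question: executeTake converts InternalRows to external Rows, and the Python path then immediately converts them back. A minimal sketch of that double conversion, assuming a Spark 1.5 SQLContext named sqlContext (the DataFrame here is illustrative, not part of the patch):

    import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}

    val df = sqlContext.range(0, 10).toDF("id")

    // DataFrame.take(n) hands back external Rows, which executeTake has
    // already converted out of the internal representation...
    val externalRows = df.take(3)

    // ...so the Python path has to convert each Row straight back into an
    // InternalRow before it can be translated for pickling. This is the
    // extra round trip the commit message wants to eliminate.
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(df.schema)
    val internalRows = externalRows.map(r => toCatalyst(r).asInstanceOf[InternalRow])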

Author: Reynold Xin <[email protected]>

Closes #8876 from rxin/SPARK-10731.
rxin committed Sep 23, 2015
1 parent 067afb4 commit 9952217
Showing 4 changed files with 25 additions and 12 deletions.
@@ -633,7 +633,7 @@ private[spark] object PythonRDD extends Logging {
    *
    * The thread will terminate after all the data are sent or any exceptions happen.
    */
-  private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+  def serveIterator[T](items: Iterator[T], threadName: String): Int = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
     // Close the socket if no connection in 3 seconds
     serverSocket.setSoTimeout(3000)
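
Dropping the private modifier lets EvaluatePython reuse this method. serveIterator (shown truncated above) binds a one-connection ServerSocket on an ephemeral localhost port, returns the port, and writes the items from a background thread. A stripped-down sketch of that pattern, not the actual implementation (the real method uses Spark's serializer and fuller error handling):

    import java.io.{BufferedOutputStream, DataOutputStream}
    import java.net.{InetAddress, ServerSocket}

    // Sketch: serve strings over a one-shot local socket and return the port.
    def serveStrings(items: Iterator[String], threadName: String): Int = {
      // Ephemeral port, backlog of 1, bound to localhost only.
      val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
      serverSocket.setSoTimeout(3000) // give up if no client connects in 3 seconds
      new Thread(threadName) {
        setDaemon(true)
        override def run(): Unit = {
          val socket = serverSocket.accept()
          val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
          try items.foreach(out.writeUTF)
          finally {
            out.close()
            serverSocket.close()
          }
        }
      }.start()
      serverSocket.getLocalPort // the caller (e.g. pyspark) connects to this
    }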
python/pyspark/sql/dataframe.py (4 additions, 1 deletion)

@@ -300,7 +300,10 @@ def take(self, num):
         >>> df.take(2)
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
-        return self.limit(num).collect()
+        with SCCallSiteSync(self._sc) as css:
+            port = self._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe(
+                self._jdf, num)
+        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
 
     @ignore_unicode_prefix
     @since(1.3)
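
On the JVM side, that Py4J gateway call resolves to the new EvaluatePython.takeAndServe entry point added in the next diff. A sketch of the equivalent direct call, again assuming a SQLContext named sqlContext and an illustrative DataFrame:

    import org.apache.spark.sql.execution.EvaluatePython

    val df = sqlContext.range(0, 100).toDF("id")

    // Pickle the first two rows and serve them on an ephemeral localhost
    // port; pyspark's _load_from_socket connects to the returned port.
    val port: Int = EvaluatePython.takeAndServe(df, 2)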
@@ -28,7 +28,8 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

@@ -118,6 +119,17 @@ object EvaluatePython {
   def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
     new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
 
+  def takeAndServe(df: DataFrame, n: Int): Int = {
+    registerPicklers()
+    // This is an annoying hack - we should refactor the code so that executeCollect and
+    // executeTake return InternalRow rather than Row.
+    val converter = CatalystTypeConverters.createToCatalystConverter(df.schema)
+    val iter = new SerDeUtil.AutoBatchedPickler(df.take(n).iterator.map { row =>
+      EvaluatePython.toJava(converter(row).asInstanceOf[InternalRow], df.schema)
+    })
+    PythonRDD.serveIterator(iter, "serve-DataFrame")
+  }
+
   /**
    * Helper for converting from Catalyst type to java type suitable for Pyrolite.
    */
@@ -39,22 +39,20 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
 
-  override def serialize(obj: Any): Seq[Double] = {
+  override def serialize(obj: Any): GenericArrayData = {
     obj match {
       case p: ExamplePoint =>
-        Seq(p.x, p.y)
+        val output = new Array[Any](2)
+        output(0) = p.x
+        output(1) = p.y
+        new GenericArrayData(output)
     }
   }
 
   override def deserialize(datum: Any): ExamplePoint = {
     datum match {
-      case values: Seq[_] =>
-        val xy = values.asInstanceOf[Seq[Double]]
-        assert(xy.length == 2)
-        new ExamplePoint(xy(0), xy(1))
-      case values: util.ArrayList[_] =>
-        val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
-        new ExamplePoint(xy(0), xy(1))
+      case values: ArrayData =>
+        new ExamplePoint(values.getDouble(0), values.getDouble(1))
     }
   }
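
With both methods speaking Catalyst's ArrayData, serialize and deserialize form a clean round trip over the internal representation. A sketch of that round trip; note the UDT is private[sql], so this only compiles from code under the org.apache.spark.sql package, and the import path is assumed from the test class shown above:

    import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT}

    // Round trip: ExamplePoint -> GenericArrayData -> ExamplePoint.
    val udt = new ExamplePointUDT
    val p = new ExamplePoint(1.0, -1.0)
    val restored = udt.deserialize(udt.serialize(p))
    assert(restored.x == 1.0 && restored.y == -1.0)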
