[SPARK-2674] [SQL] [PySpark] support datetime type for SchemaRDD
Python datetime and time values are converted into java.util.Calendar after serialization; during inferSchema() they are then converted into java.sql.Timestamp.

In javaToPython(), Timestamp is converted into Calendar and, after pickling, comes back as a datetime in Python.
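
As a rough, self-contained Scala sketch of the two conversions described above (the object and method names below are hypothetical; they only mirror the behaviour this commit adds: Calendar to Timestamp on the way into inferSchema(), Timestamp to Calendar on the way back out through javaToPython()):

import java.sql.Timestamp
import java.util.Calendar

// Hypothetical helpers mirroring the conversions performed by this commit.
object DatetimeBridge {
  // A pickled Python datetime arrives on the JVM as a java.util.Calendar;
  // inferSchema() stores it as a java.sql.Timestamp (TimestampType).
  def toTimestamp(c: Calendar): Timestamp =
    new Timestamp(c.getTime().getTime())

  // javaToPython() goes the other way, so Pyrolite can pickle the value
  // back into a Python datetime on the worker side.
  def toCalendar(t: Timestamp): Calendar = {
    val c = Calendar.getInstance()
    c.setTimeInMillis(t.getTime())
    c
  }
}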

Author: Davies Liu <[email protected]>

Closes apache#1601 from davies/date and squashes the following commits:

f0599b0 [Davies Liu] remove tests for sets and tuple in sql, fix list of list
c9d607a [Davies Liu] convert datetype for runtime
709d40d [Davies Liu] remove brackets
96db384 [Davies Liu] support datetime type for SchemaRDD
davies authored and marmbrus committed Jul 29, 2014
1 parent e364348 commit f0d880e
Showing 4 changed files with 68 additions and 44 deletions.
@@ -550,11 +550,11 @@ private[spark] object PythonRDD extends Logging {
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
// TODO: Figure out why flatMap is necessary for pyspark
iter.flatMap { row =>
unpickle.loads(row) match {
// in case objects are pickled in batch mode
case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
// In case the partition doesn't have a collection
// not in batch mode
case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
}
}
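
For context, the hunk above dispatches on the shape of one unpickled payload. A standalone sketch of that dispatch, assuming Pyrolite's net.razorvine.pickle.Unpickler (which the code above already uses); the function name is hypothetical:

import scala.collection.JavaConverters._
import net.razorvine.pickle.Unpickler

// A single pickled payload holds either a batch (an ArrayList of maps, when
// rows were pickled in batch mode) or one map; both shapes are flattened
// into Scala Maps.
def unpickleRows(payload: Array[Byte]): Seq[Map[String, Any]] =
  new Unpickler().loads(payload) match {
    case batch: java.util.ArrayList[_] =>
      batch.asScala.map(_.asInstanceOf[java.util.Map[String, Any]].asScala.toMap).toSeq
    case single: java.util.Map[_, _] =>
      Seq(single.asInstanceOf[java.util.Map[String, Any]].asScala.toMap)
  }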
22 changes: 12 additions & 10 deletions python/pyspark/sql.py
@@ -47,12 +47,14 @@ def __init__(self, sparkContext, sqlContext=None):
...
ValueError:...
>>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
... "boolean" : True}])
>>> from datetime import datetime
>>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L,
... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1},
... "list": [1, 2, 3]}])
>>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
... x.boolean))
... x.boolean, x.time, x.dict["a"], x.list))
>>> srdd.collect()[0]
(1, u'string', 1.0, 1, True)
(1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3])
"""
self._sc = sparkContext
self._jsc = self._sc._jsc
@@ -88,13 +90,13 @@ def inferSchema(self, rdd):
>>> from array import array
>>> srdd = sqlCtx.inferSchema(nestedRdd1)
>>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
>>> srdd.collect() == [{"f1" : [1, 2], "f2" : {"row1" : 1.0}},
... {"f1" : [2, 3], "f2" : {"row2" : 2.0}}]
True
>>> srdd = sqlCtx.inferSchema(nestedRdd2)
>>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
>>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2]},
... {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3]}]
True
"""
if (rdd.__class__ is SchemaRDD):
@@ -509,8 +511,8 @@ def _test():
{"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
{"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
globs['nestedRdd2'] = sc.parallelize([
{"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)},
{"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}])
{"f1": [[1, 2], [2, 3]], "f2": [1, 2]},
{"f1": [[2, 3], [3, 4]], "f2": [2, 3]}])
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
40 changes: 37 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -352,8 +352,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
case c: java.lang.Long => LongType
case c: java.lang.Double => DoubleType
case c: java.lang.Boolean => BooleanType
case c: java.math.BigDecimal => DecimalType
case c: java.sql.Timestamp => TimestampType
case c: java.util.Calendar => TimestampType
case c: java.util.List[_] => ArrayType(typeFor(c.head))
case c: java.util.Set[_] => ArrayType(typeFor(c.head))
case c: java.util.Map[_, _] =>
val (key, value) = c.head
MapType(typeFor(key), typeFor(value))
@@ -362,11 +364,43 @@ class SQLContext(@transient val sparkContext: SparkContext)
ArrayType(typeFor(elem))
case c => throw new Exception(s"Object of type $c cannot be used")
}
val schema = rdd.first().map { case (fieldName, obj) =>
val firstRow = rdd.first()
val schema = firstRow.map { case (fieldName, obj) =>
AttributeReference(fieldName, typeFor(obj), true)()
}.toSeq

val rowRdd = rdd.mapPartitions { iter =>
def needTransform(obj: Any): Boolean = obj match {
case c: java.util.List[_] => true
case c: java.util.Map[_, _] => true
case c if c.getClass.isArray => true
case c: java.util.Calendar => true
case c => false
}

// convert JList, JArray into Seq, convert JMap into Map
// convert Calendar into Timestamp
def transform(obj: Any): Any = obj match {
case c: java.util.List[_] => c.map(transform).toSeq
case c: java.util.Map[_, _] => c.map {
case (key, value) => (key, transform(value))
}.toMap
case c if c.getClass.isArray =>
c.asInstanceOf[Array[_]].map(transform).toSeq
case c: java.util.Calendar =>
new java.sql.Timestamp(c.getTime().getTime())
case c => c
}

val need = firstRow.exists {case (key, value) => needTransform(value)}
val transformed = if (need) {
rdd.mapPartitions { iter =>
iter.map {
m => m.map {case (key, value) => (key, transform(value))}
}
}
} else rdd

val rowRdd = transformed.mapPartitions { iter =>
iter.map { map =>
new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
}
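
To make the runtime conversion above concrete, here is a small self-contained sketch using only the JDK and Scala collections; the function below is a simplified restatement of transform() from this hunk, not the exact production code:

import scala.collection.JavaConverters._

// Java collections and Calendars produced by Pyrolite are rewritten into the
// Scala Seq/Map and java.sql.Timestamp values that Catalyst rows expect.
def pythonToCatalyst(obj: Any): Any = obj match {
  case c: java.util.List[_]    => c.asScala.map(pythonToCatalyst).toSeq
  case c: java.util.Map[_, _]  => c.asScala.map { case (k, v) => (k, pythonToCatalyst(v)) }.toMap
  case c if c.getClass.isArray => c.asInstanceOf[Array[_]].map(pythonToCatalyst).toSeq
  case c: java.util.Calendar   => new java.sql.Timestamp(c.getTime().getTime())
  case other                   => other
}

// Example: a record {"time": <Calendar>, "tags": ["a", "b"]} coming from Python.
val record = new java.util.HashMap[String, Any]()
record.put("time", java.util.Calendar.getInstance())
record.put("tags", java.util.Arrays.asList("a", "b"))
val converted = record.asScala.map { case (k, v) => (k, pythonToCatalyst(v)) }.toMap
// converted("time") is now a java.sql.Timestamp and converted("tags") is Seq("a", "b").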
46 changes: 17 additions & 29 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType}
import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, BooleanType, StructType, MapType}
import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
import org.apache.spark.api.java.JavaRDD

@@ -376,39 +376,27 @@ class SchemaRDD(
* Converts a JavaRDD to a PythonRDD. It is used by pyspark.
*/
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
def toJava(obj: Any, dataType: DataType): Any = dataType match {
case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
case array: ArrayType => obj match {
case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
case arr if arr != null && arr.getClass.isArray =>
arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
case other => other
}
case mt: MapType => obj.asInstanceOf[Map[_, _]].map {
case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
}.asJava
// Pyrolite can handle Timestamp
case other => obj
}
def rowToMap(row: Row, structType: StructType): JMap[String, Any] = {
val fields = structType.fields.map(field => (field.name, field.dataType))
val map: JMap[String, Any] = new java.util.HashMap
row.zip(fields).foreach {
case (obj, (attrName, dataType)) =>
dataType match {
case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct))
case array @ ArrayType(struct: StructType) =>
val arrayValues = obj match {
case seq: Seq[Any] =>
seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
case list: JList[_] =>
list.map(element => rowToMap(element.asInstanceOf[Row], struct))
case set: JSet[_] =>
set.map(element => rowToMap(element.asInstanceOf[Row], struct))
case arr if arr != null && arr.getClass.isArray =>
arr.asInstanceOf[Array[Any]].map {
element => rowToMap(element.asInstanceOf[Row], struct)
}
case other => other
}
map.put(attrName, arrayValues)
case array: ArrayType => {
val arrayValues = obj match {
case seq: Seq[Any] => seq.asJava
case other => other
}
map.put(attrName, arrayValues)
}
case other => map.put(attrName, obj)
}
case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType))
}

map
}
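
The comment above ("Pyrolite can handle Timestamp") is why toJava passes Timestamp values through untouched. A quick, standalone way to see this, using Pyrolite's Pickler directly (the variable names here are arbitrary):

import net.razorvine.pickle.Pickler

// Pyrolite knows how to pickle java.sql.Timestamp, so javaToPython() can hand
// Timestamps straight to the pickler; the Python worker unpickles them as
// datetime.datetime objects, as the doctest in sql.py shows.
val pickler = new Pickler()
val ts = new java.sql.Timestamp(System.currentTimeMillis())
val payload: Array[Byte] = pickler.dumps(ts)
// `payload` is the kind of byte array javaToPython() ships to the Python side.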
