[SPARK-12443][SQL] encoderFor should support Decimal
## What changes were proposed in this pull request?

JIRA: https://issues.apache.org/jira/browse/SPARK-12443

`constructorFor` calls `dataTypeFor` to determine whether a type is an `ObjectType` or not. Because there is no case for `Decimal`, it is recognized as an `ObjectType`, which causes the bug.
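
As a hedged illustration (not part of the original report; the Spark 1.6/2.0-era catalyst encoder API is assumed), this is the kind of derivation that goes wrong without the fix:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.Decimal

// Sketch: deriving an encoder for a Catalyst Decimal. Before this patch,
// dataTypeFor had no Decimal case, so the value was typed as an opaque
// ObjectType instead of DecimalType.SYSTEM_DEFAULT and encoding misbehaved.
val encoder = ExpressionEncoder[Decimal]()
val internalRow = encoder.toRow(Decimal("32131413.211321313"))
```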

## How was this patch tested?

A test is added to `ExpressionEncoderSuite`.

Author: Liang-Chi Hsieh <[email protected]>

Closes apache#10399 from viirya/fix-encoder-decimal.
viirya authored and marmbrus committed Mar 25, 2016
1 parent 11fa874 commit ca00335
Showing 5 changed files with 47 additions and 4 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -63,6 +63,7 @@ object ScalaReflection extends ScalaReflection {
       case t if t <:< definitions.ByteTpe => ByteType
       case t if t <:< definitions.BooleanTpe => BooleanType
       case t if t <:< localTypeOf[Array[Byte]] => BinaryType
+      case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT
       case _ =>
         val className = getClassNameFromType(tpe)
         className match {
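
For context, `dataTypeFor` maps an external Scala type to a Catalyst `DataType`, with unknown types falling through to `ObjectType`. A value-level analogue of that type-level match (a simplified sketch, not the actual method, which matches on `scala.reflect` `Type`s):

```scala
import org.apache.spark.sql.types._

// Simplified, value-level analogue of ScalaReflection.dataTypeFor: known
// types map to concrete Catalyst DataTypes; everything else becomes an
// ObjectType, which is the path Decimal incorrectly took before this patch.
def dataTypeForSketch(value: Any): DataType = value match {
  case _: Byte        => ByteType
  case _: Boolean     => BooleanType
  case _: Array[Byte] => BinaryType
  case _: Decimal     => DecimalType.SYSTEM_DEFAULT // the case this patch adds
  case other          => ObjectType(other.getClass) // former fallback for Decimal
}
```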
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -79,7 +79,7 @@ object RowEncoder {
       StaticInvoke(
         Decimal.getClass,
         DecimalType.SYSTEM_DEFAULT,
-        "apply",
+        "fromDecimal",
         inputObject :: Nil)
 
     case StringType =>
@@ -95,7 +95,7 @@ object RowEncoder {
           classOf[GenericArrayData],
           inputObject :: Nil,
           dataType = t)
-        case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeFor(et))
+        case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeForInput(et))
       }
 
     case t @ MapType(kt, vt, valueNullable) =>
@@ -129,14 +129,29 @@ object RowEncoder {
         Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
         Literal.create(null, f.dataType),
         extractorsFor(
-          Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
+          Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
           f.dataType))
     }
     If(IsNull(inputObject),
       Literal.create(null, inputType),
       CreateStruct(convertedFields))
   }
 
+  /**
+   * Returns the `DataType` that can be used when generating code that converts input data
+   * into the Spark SQL internal format. Unlike `externalDataTypeFor`, the `DataType` returned
+   * by this function can be more permissive since multiple external types may map to a single
+   * internal type. For example, for an input with DecimalType in external row, its external types
+   * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
+   * `org.apache.spark.sql.types.Decimal`.
+   */
+  private def externalDataTypeForInput(dt: DataType): DataType = dt match {
+    // In order to support both Decimal and java BigDecimal in external row, we make this
+    // as java.lang.Object.
+    case _: DecimalType => ObjectType(classOf[java.lang.Object])
+    case _ => externalDataTypeFor(dt)
+  }
+
   private def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
     case CalendarIntervalType => dt
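
The net effect of `externalDataTypeForInput` together with the switch from `"apply"` to `"fromDecimal"`: when serializing, a `DecimalType` field is extracted as a plain `java.lang.Object` and normalized afterwards, so the external `Row` may hold more than one representation. A hedged usage sketch (era-appropriate `RowEncoder`/`toRow` API assumed):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{Decimal, DecimalType, StructType}

// Both rows encode to the same internal form: the extractor is typed as
// java.lang.Object, and Decimal.fromDecimal dispatches on the runtime class.
val schema  = new StructType().add("d", DecimalType.SYSTEM_DEFAULT)
val encoder = RowEncoder(schema)

val fromCatalyst = encoder.toRow(Row(Decimal("1.23")))
val fromJava     = encoder.toRow(Row(new java.math.BigDecimal("1.23")))
```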
sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -376,6 +376,14 @@ object Decimal {
 
   def apply(value: String): Decimal = new Decimal().set(BigDecimal(value))
 
+  // This is used for RowEncoder to handle Decimal inside external row.
+  def fromDecimal(value: Any): Decimal = {
+    value match {
+      case j: java.math.BigDecimal => apply(j)
+      case d: Decimal => d
+    }
+  }
+
   /**
    * Creates a decimal from unscaled, precision and scale without checking the bounds.
    */
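
A quick usage note on the new helper, based only on the cases shown above:

```scala
import org.apache.spark.sql.types.Decimal

// fromDecimal normalizes both supported external representations to a
// Catalyst Decimal.
val fromJava: Decimal    = Decimal.fromDecimal(new java.math.BigDecimal("231341.23123"))
val passThrough: Decimal = Decimal.fromDecimal(Decimal("231341.23123"))
// Any other input (e.g. scala.math.BigDecimal) falls through the match and
// throws a MatchError in this version of the helper.
```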
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, ObjectType, StructType}
+import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType}
 
 case class RepeatedStruct(s: Seq[PrimitiveData])
 
@@ -101,6 +101,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
   // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
 
+  encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")
+
   encodeDecodeTest("hello", "string")
   encodeDecodeTest(Date.valueOf("2012-12-23"), "date")
   encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp")
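
The new `catalyst decimal` case drives a `Decimal` value through an encoder derived by `ScalaReflection`, exercising the `dataTypeFor` case added above. The decimal flavors around these tests carry the same numeric value; a small hedged sketch of how they relate:

```scala
import org.apache.spark.sql.types.Decimal

// The external decimal flavors around these tests (the java.math case is
// still commented out in this commit): all represent the same value.
val scalaDec: BigDecimal          = BigDecimal("32131413.211321313")
val catalystDec: Decimal          = Decimal("32131413.211321313")
val javaDec: java.math.BigDecimal = scalaDec.underlying()
assert(catalystDec.toBigDecimal == scalaDec)
```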
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -143,6 +143,23 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(input.getStruct(0) == convertedBack.getStruct(0))
   }
 
+  test("encode/decode Decimal") {
+    val schema = new StructType()
+      .add("int", IntegerType)
+      .add("string", StringType)
+      .add("double", DoubleType)
+      .add("decimal", DecimalType.SYSTEM_DEFAULT)
+
+    val encoder = RowEncoder(schema)
+
+    val input: Row = Row(100, "test", 0.123, Decimal(1234.5678))
+    val row = encoder.toRow(input)
+    val convertedBack = encoder.fromRow(row)
+    // Decimal inside external row will be converted back to Java BigDecimal when decoding.
+    assert(input.get(3).asInstanceOf[Decimal].toJavaBigDecimal
+      .compareTo(convertedBack.getDecimal(3)) == 0)
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema)
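
One detail worth noting in the test above: it compares with `compareTo` rather than `equals`, because `java.math.BigDecimal.equals` is scale-sensitive while `compareTo` compares numeric value only:

```scala
// equals distinguishes 1234.5678 from 1234.56780 (different scale), while
// compareTo treats them as the same number; hence the compareTo in the test.
val a = new java.math.BigDecimal("1234.5678")
val b = new java.math.BigDecimal("1234.56780")
assert(!a.equals(b))        // scale-sensitive
assert(a.compareTo(b) == 0) // numerically equal
```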
