[SPARK-3930] [SPARK-3933] Support fixed-precision decimal in SQL, and some optimizations

- Adds optional precision and scale to Spark SQL's decimal type, which behave similarly to those in Hive 13 (https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf)
- Replaces our internal representation of decimals with a Decimal class that can store small values in a mutable Long, saving memory for such values and letting some operations happen directly on Longs

This is still marked WIP because there are a few TODOs, but I'll remove that tag when done.
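
As a concrete illustration of the second bullet, here is a minimal, self-contained sketch of the mutable-Long technique. Everything in it is hypothetical (class and method names included) and it skips precision and overflow handling; the real implementation is the commit's org.apache.spark.sql.catalyst.types.decimal.Decimal.

// Illustrative sketch only (hypothetical class): keep the unscaled value of a
// small decimal in a mutable Long and fall back to java.math.BigDecimal otherwise.
final class SmallDecimal {
  private var longVal: Long = 0L                   // fast-path storage
  private var bigVal: java.math.BigDecimal = null  // fallback for large values
  private var decimalScale: Int = 0

  /** Reset from an unscaled Long without allocating any objects. */
  def set(unscaled: Long, scale: Int): this.type = {
    longVal = unscaled
    bigVal = null
    decimalScale = scale
    this
  }

  /** Reset from a BigDecimal (the slow path). */
  def set(v: java.math.BigDecimal): this.type = {
    bigVal = v
    decimalScale = v.scale
    this
  }

  /** Add in place; two small values with equal scale never leave Long arithmetic. */
  def addInPlace(other: SmallDecimal): this.type = {
    if (bigVal == null && other.bigVal == null && decimalScale == other.decimalScale) {
      longVal += other.longVal  // no boxing, no BigDecimal allocation
      this
    } else {
      set(toJavaBigDecimal.add(other.toJavaBigDecimal))
    }
  }

  def toJavaBigDecimal: java.math.BigDecimal =
    if (bigVal ne null) bigVal else java.math.BigDecimal.valueOf(longVal, decimalScale)

  override def toString: String = toJavaBigDecimal.toString
}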

Author: Matei Zaharia <[email protected]>

Closes apache#2983 from mateiz/decimal-1 and squashes the following commits:

35e6b02 [Matei Zaharia] Fix issues after merge
227f24a [Matei Zaharia] Review comments
31f915e [Matei Zaharia] Implement Davies's suggestions in Python
eb84820 [Matei Zaharia] Support reading/writing decimals as fixed-length binary in Parquet
4dc6bae [Matei Zaharia] Fix decimal support in PySpark
d1d9d68 [Matei Zaharia] Fix compile error and test issues after rebase
b28933d [Matei Zaharia] Support decimal precision/scale in Hive metastore
2118c0d [Matei Zaharia] Some test and bug fixes
81db9cb [Matei Zaharia] Added mutable Decimal that will be more efficient for small precisions
7af0c3b [Matei Zaharia] Add optional precision and scale to DecimalType, but use Unlimited for now
ec0a947 [Matei Zaharia] Make the result of AVG on Decimals be Decimal, not Double
mateiz authored and marmbrus committed Nov 2, 2014
1 parent 56f2c61 · commit 23f966f
Showing 55 changed files with 1,636 additions and 232 deletions.
35 changes: 32 additions & 3 deletions python/pyspark/sql.py
@@ -35,6 +35,7 @@
import keyword
import warnings
import json
import re
from array import array
from operator import itemgetter
from itertools import imap
@@ -148,13 +149,30 @@ class TimestampType(PrimitiveType):
"""


class DecimalType(PrimitiveType):
class DecimalType(DataType):

"""Spark SQL DecimalType
The data type representing decimal.Decimal values.
"""

def __init__(self, precision=None, scale=None):
self.precision = precision
self.scale = scale
self.hasPrecisionInfo = precision is not None

def jsonValue(self):
if self.hasPrecisionInfo:
return "decimal(%d,%d)" % (self.precision, self.scale)
else:
return "decimal"

def __repr__(self):
if self.hasPrecisionInfo:
return "DecimalType(%d,%d)" % (self.precision, self.scale)
else:
return "DecimalType()"


class DoubleType(PrimitiveType):

@@ -446,9 +464,20 @@ def _parse_datatype_json_string(json_string):
    return _parse_datatype_json_value(json.loads(json_string))


_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")


def _parse_datatype_json_value(json_value):
    if type(json_value) is unicode and json_value in _all_primitive_types.keys():
        return _all_primitive_types[json_value]()
    if type(json_value) is unicode:
        if json_value in _all_primitive_types.keys():
            return _all_primitive_types[json_value]()
        elif json_value == u'decimal':
            return DecimalType()
        elif _FIXED_DECIMAL.match(json_value):
            m = _FIXED_DECIMAL.match(json_value)
            return DecimalType(int(m.group(1)), int(m.group(2)))
        else:
            raise ValueError("Could not parse datatype: %s" % json_value)
    else:
        return _all_complex_types[json_value["type"]].fromJson(json_value)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst

import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.types.decimal.Decimal

/**
 * Provides experimental support for generating catalyst schemas for scala objects.
 */
@@ -40,9 +41,20 @@ object ScalaReflection {
    case s: Seq[_] => s.map(convertToCatalyst)
    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
    case d: BigDecimal => Decimal(d)
    case other => other
  }

  /** Converts Catalyst types used internally in rows to standard Scala types */
  def convertToScala(a: Any): Any = a match {
    case s: Seq[_] => s.map(convertToScala)
    case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) }
    case d: Decimal => d.toBigDecimal
    case other => other
  }

  def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))

  /** Returns a Sequence of attributes for the given case class type. */
  def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
    case Schema(s: StructType, _) =>
@@ -83,7 +95,8 @@ object ScalaReflection {
      case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
      case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
      case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
      case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
      case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
      case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
      case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
      case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
      case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
@@ -111,8 +124,9 @@
    case obj: LongType.JvmType => LongType
    case obj: FloatType.JvmType => FloatType
    case obj: DoubleType.JvmType => DoubleType
    case obj: DecimalType.JvmType => DecimalType
    case obj: DateType.JvmType => DateType
    case obj: BigDecimal => DecimalType.Unlimited
    case obj: Decimal => DecimalType.Unlimited
    case obj: TimestampType.JvmType => TimestampType
    case null => NullType
    // For other cases, there is no obvious mapping from the type of the given object to a
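
Taken together, convertToCatalyst and convertToScala round-trip decimals between scala.math.BigDecimal and the internal Decimal wrapper. A hedged usage sketch, assuming a build that includes this commit:

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.decimal.Decimal

// Scala BigDecimal crosses into Catalyst as the internal Decimal wrapper...
val internal = ScalaReflection.convertToCatalyst(BigDecimal("12.34"))
assert(internal.isInstanceOf[Decimal])

// ...and convertToScala turns it back into a plain BigDecimal.
val external = ScalaReflection.convertToScala(internal)
assert(external == BigDecimal("12.34"))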
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -52,11 +52,13 @@ class SqlParser extends AbstractSparkSQLParser {
  protected val CASE = Keyword("CASE")
  protected val CAST = Keyword("CAST")
  protected val COUNT = Keyword("COUNT")
  protected val DECIMAL = Keyword("DECIMAL")
  protected val DESC = Keyword("DESC")
  protected val DISTINCT = Keyword("DISTINCT")
  protected val ELSE = Keyword("ELSE")
  protected val END = Keyword("END")
  protected val EXCEPT = Keyword("EXCEPT")
  protected val DOUBLE = Keyword("DOUBLE")
  protected val FALSE = Keyword("FALSE")
  protected val FIRST = Keyword("FIRST")
  protected val FROM = Keyword("FROM")
@@ -385,5 +387,15 @@
  }

  protected lazy val dataType: Parser[DataType] =
    STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType
    ( STRING ^^^ StringType
    | TIMESTAMP ^^^ TimestampType
    | DOUBLE ^^^ DoubleType
    | fixedDecimalType
    | DECIMAL ^^^ DecimalType.Unlimited
    )

  protected lazy val fixedDecimalType: Parser[DataType] =
    (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
      case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
    }
}
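
With the extended grammar, a CAST target can now name a fixed-precision decimal, an unlimited-precision decimal, or a double. A hedged sketch of the new syntax; the table and column names are hypothetical, and sqlContext is an assumed SQLContext:

// Assumption: sqlContext is an org.apache.spark.sql.SQLContext built on an existing SparkContext.
sqlContext.sql("SELECT CAST(price AS DECIMAL(10, 2)) FROM items")  // parsed via fixedDecimalType
sqlContext.sql("SELECT CAST(price AS DECIMAL) FROM items")         // DecimalType.Unlimited
sqlContext.sql("SELECT CAST(price AS DOUBLE) FROM items")          // DOUBLE is now a keyword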
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -25,19 +25,31 @@ import org.apache.spark.sql.catalyst.types._
object HiveTypeCoercion {
  // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
  // The conversions for integral and floating point types have a linear widening hierarchy:
  val numericPrecedence =
    Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
  val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil
  private val numericPrecedence =
    Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited)

  /**
   * Find the tightest common type of two types that might be used in a binary expression.
   * This handles all numeric types except fixed-precision decimals interacting with each other or
   * with primitive types, because in that case the precision and scale of the result depends on
   * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
   */
  def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
    val valueTypes = Seq(t1, t2).filter(t => t != NullType)
    if (valueTypes.distinct.size > 1) {
      // Try and find a promotion rule that contains both types in question.
      val applicableConversion =
        HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))

      // If found return the widest common type, otherwise None
      applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
      // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
      if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) {
        Some(numericPrecedence.filter(t => t == t1 || t == t2).last)
      } else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) {
        // Fixed-precision decimals can up-cast into unlimited
        if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) {
          Some(DecimalType.Unlimited)
        } else {
          None
        }
      } else {
        None
      }
    } else {
      Some(if (valueTypes.size == 0) NullType else valueTypes.head)
    }
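
A few hedged examples of what the rewritten findTightestCommonType should return, reading directly off the cases above (expected results in comments):

import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.findTightestCommonType
import org.apache.spark.sql.catalyst.types._

findTightestCommonType(IntegerType, LongType)                    // Some(LongType)
findTightestCommonType(DoubleType, DecimalType.Unlimited)        // Some(DecimalType.Unlimited)
findTightestCommonType(DecimalType(5, 2), DecimalType.Unlimited) // Some(DecimalType.Unlimited)
findTightestCommonType(DecimalType(5, 2), DecimalType(10, 0))    // None: left to DecimalPrecision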
@@ -59,6 +71,7 @@ trait HiveTypeCoercion {
    ConvertNaNs ::
    WidenTypes ::
    PromoteStrings ::
    DecimalPrecision ::
    BooleanComparisons ::
    BooleanCasts ::
    StringToIntegralCasts ::
@@ -151,6 +164,7 @@ trait HiveTypeCoercion {
    import HiveTypeCoercion._

    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      // TODO: unions with fixed-precision decimals
      case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
        val castedInput = left.output.zip(right.output).map {
          // When a string is found on one side, make the other side a string too.
@@ -265,6 +279,110 @@ trait HiveTypeCoercion {
    }
  }

  // scalastyle:off
  /**
   * Calculates and propagates precision for fixed-precision decimals. Hive has a number of
   * rules for this based on the SQL standard and MS SQL:
   * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
   *
   * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
   * respectively, then the following operations have the following precision / scale:
   *
   *   Operation    Result Precision                         Result Scale
   *   ------------------------------------------------------------------------
   *   e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1      max(s1, s2)
   *   e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1      max(s1, s2)
   *   e1 * e2      p1 + p2 + 1                              s1 + s2
   *   e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)       max(6, s1 + p2 + 1)
   *   e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)          max(s1, s2)
   *   sum(e1)      p1 + 10                                  s1
   *   avg(e1)      p1 + 4                                   s1 + 4
   *
   * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision.
   *
   * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
   * precision, do the math on unlimited-precision numbers, then introduce casts back to the
   * required fixed precision. This allows us to do all rounding and overflow handling in the
   * cast-to-fixed-precision operator.
   *
   * In addition, when mixing non-decimal types with decimals, we use the following rules:
   * - BYTE gets turned into DECIMAL(3, 0)
   * - SHORT gets turned into DECIMAL(5, 0)
   * - INT gets turned into DECIMAL(10, 0)
   * - LONG gets turned into DECIMAL(20, 0)
   * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
   *   but note that unlimited decimals are considered bigger than doubles in WidenTypes)
   */
  // scalastyle:on
  object DecimalPrecision extends Rule[LogicalPlan] {
    import scala.math.{max, min}

    // Conversion rules for integer types into fixed-precision decimals
    val intTypeToFixed: Map[DataType, DecimalType] = Map(
      ByteType -> DecimalType(3, 0),
      ShortType -> DecimalType(5, 0),
      IntegerType -> DecimalType(10, 0),
      LongType -> DecimalType(20, 0)
    )

    def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType

    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
      // Skip nodes whose children have not been resolved yet
      case e if !e.childrenResolved => e

      case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
        Cast(
          Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
          DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
        )

      case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
        Cast(
          Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
          DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
        )

      case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
        Cast(
          Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
          DecimalType(p1 + p2 + 1, s1 + s2)
        )

      case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
        Cast(
          Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
          DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
        )

      case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
        Cast(
          Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
          DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
        )

      // Promote integers inside a binary expression with fixed-precision decimals to decimals,
      // and fixed-precision decimals in an expression with floats / doubles to doubles
      case b: BinaryExpression if b.left.dataType != b.right.dataType =>
        (b.left.dataType, b.right.dataType) match {
          case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
            b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
          case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
            b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
          case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
            b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
          case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
            b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
          case _ =>
            b
        }

      // TODO: MaxOf, MinOf, etc might want other rules

      // SUM and AVERAGE are handled by the implementations of those expressions
    }
  }
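
A worked instance of the addition rule above, assuming two inputs typed DECIMAL(5, 2) and DECIMAL(7, 1); the attribute names are hypothetical:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

val e1 = AttributeReference("a", DecimalType(5, 2), nullable = true)()
val e2 = AttributeReference("b", DecimalType(7, 1), nullable = true)()

// scale     = max(2, 1)                 = 2
// precision = max(5 - 2, 7 - 1) + 2 + 1 = 9
// so DecimalPrecision rewrites Add(e1, e2) to:
//   Cast(Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
//        DecimalType(9, 2))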

  /**
   * Changes Boolean values to Bytes so that expressions like true < false can be evaluated.
   */
@@ -330,7 +448,7 @@ trait HiveTypeCoercion {
      case e if !e.childrenResolved => e

      case Cast(e @ StringType(), t: IntegralType) =>
        Cast(Cast(e, DecimalType), t)
        Cast(Cast(e, DecimalType.Unlimited), t)
    }
  }

@@ -383,10 +501,12 @@ trait HiveTypeCoercion {

      // Decimal and Double remain the same
      case d: Divide if d.resolved && d.dataType == DoubleType => d
      case d: Divide if d.resolved && d.dataType == DecimalType => d
      case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d

      case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType))
      case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r)
      case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
        Divide(l, Cast(r, DecimalType.Unlimited))
      case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
        Divide(Cast(l, DecimalType.Unlimited), r)

      case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
    }
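
Read in isolation, the updated Division cases keep decimal division in decimal space instead of falling back to doubles. A hedged sketch (column name hypothetical; other rules in the batch, such as DecimalPrecision, may rewrite the same tree first):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

val amount = AttributeReference("amount", DecimalType.Unlimited, nullable = true)()

// Only the non-decimal side is cast:
//   Divide(amount, Literal(2)) => Divide(amount, Cast(Literal(2), DecimalType.Unlimited))
// With no decimal (or double) operand, both sides are still cast to DoubleType.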
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst

import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.types.decimal.Decimal

import scala.language.implicitConversions

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -124,7 +126,8 @@ package object dsl {
    implicit def doubleToLiteral(d: Double) = Literal(d)
    implicit def stringToLiteral(s: String) = Literal(s)
    implicit def dateToLiteral(d: Date) = Literal(d)
    implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
    implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d)
    implicit def decimalToLiteral(d: Decimal) = Literal(d)
    implicit def timestampToLiteral(t: Timestamp) = Literal(t)
    implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)

@@ -183,7 +186,11 @@ package object dsl {
    def date = AttributeReference(s, DateType, nullable = true)()

    /** Creates a new AttributeReference of type decimal */
    def decimal = AttributeReference(s, DecimalType, nullable = true)()
    def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)()

    /** Creates a new AttributeReference of type decimal with the given precision and scale */
    def decimal(precision: Int, scale: Int) =
      AttributeReference(s, DecimalType(precision, scale), nullable = true)()

    /** Creates a new AttributeReference of type timestamp */
    def timestamp = AttributeReference(s, TimestampType, nullable = true)()
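
A hedged sketch of how the extended DSL reads, assuming the dsl.expressions conversions from this file; the attribute name is hypothetical:

import org.apache.spark.sql.catalyst.dsl.expressions._

val unlimited = 'price.decimal          // AttributeReference of type DecimalType.Unlimited
val fixed     = 'price.decimal(10, 2)   // AttributeReference of type DecimalType(10, 2)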