Skip to content

Commit

Permalink
[SPARK-6548] Adding stddev to DataFrame functions
Browse files Browse the repository at this point in the history
Adds STDDEV support for DataFrame, using a one-pass online/parallel algorithm to compute the variance. Please review the code change.

Author: JihongMa <[email protected]>
Author: Jihong MA <[email protected]>
Author: Jihong MA <[email protected]>
Author: Jihong MA <[email protected]>

Closes apache#6297 from JihongMA/SPARK-SQL.
  • Loading branch information
JihongMA authored and davies committed Sep 12, 2015
1 parent 22730ad commit f4a2280
Show file tree
Hide file tree
Showing 16 changed files with 574 additions and 64 deletions.
2 changes: 1 addition & 1 deletion R/pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,7 @@ test_that("describe() and summarize() on a DataFrame", {
stats <- describe(df, "age")
expect_equal(collect(stats)[1, "summary"], "count")
expect_equal(collect(stats)[2, "age"], "24.5")
expect_equal(collect(stats)[3, "age"], "5.5")
expect_equal(collect(stats)[3, "age"], "7.7781745930520225")
stats <- describe(df)
expect_equal(collect(stats)[4, "name"], "Andy")
expect_equal(collect(stats)[5, "age"], "30")
Expand Down
36 changes: 18 additions & 18 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,25 +653,25 @@ def describe(self, *cols):
guarantee about the backward compatibility of the schema of the resulting DataFrame.
>>> df.describe().show()
+-------+---+
|summary|age|
+-------+---+
| count| 2|
| mean|3.5|
| stddev|1.5|
| min| 2|
| max| 5|
+-------+---+
+-------+------------------+
|summary| age|
+-------+------------------+
| count| 2|
| mean| 3.5|
| stddev|2.1213203435596424|
| min| 2|
| max| 5|
+-------+------------------+
>>> df.describe(['age', 'name']).show()
+-------+---+-----+
|summary|age| name|
+-------+---+-----+
| count| 2| 2|
| mean|3.5| null|
| stddev|1.5| null|
| min| 2|Alice|
| max| 5| Bob|
+-------+---+-----+
+-------+------------------+-----+
|summary| age| name|
+-------+------------------+-----+
| count| 2| 2|
| mean| 3.5| null|
| stddev|2.1213203435596424| null|
| min| 2|Alice|
| max| 5| Bob|
+-------+------------------+-----+
"""
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ object FunctionRegistry {
expression[Last]("last"),
expression[Max]("max"),
expression[Min]("min"),
expression[Stddev]("stddev"),
expression[StddevPop]("stddev_pop"),
expression[StddevSamp]("stddev_samp"),
expression[Sum]("sum"),

// string functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ object HiveTypeCoercion {
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType))
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ package object dsl {
// Wraps `e` in a `Lower` expression.
def lower(e: Expression): Expression = Lower(e)
// Wraps `e` in a `Sqrt` expression.
def sqrt(e: Expression): Expression = Sqrt(e)
// Wraps `e` in an `Abs` expression.
def abs(e: Expression): Expression = Abs(e)
// Wraps `e` in a `Stddev` aggregate expression.
def stddev(e: Expression): Expression = Stddev(e)
// Wraps `e` in a `StddevPop` aggregate expression.
def stddev_pop(e: Expression): Expression = StddevPop(e)
// Wraps `e` in a `StddevSamp` aggregate expression.
def stddev_samp(e: Expression): Expression = StddevSamp(e)

implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,149 @@ case class Min(child: Expression) extends AlgebraicAggregate {
override val evaluateExpression = min
}

/**
 * Sample standard deviation of a column.
 *
 * Reports `isSample = true` to [[StddevAgg]] and identifies itself with the
 * pretty name `stddev`.
 *
 * @param child the expression whose values are aggregated
 */
case class Stddev(child: Expression) extends StddevAgg(child) {
  override def prettyName: String = "stddev"

  override def isSample: Boolean = true
}

/**
 * Population standard deviation of a column.
 *
 * Reports `isSample = false` to [[StddevAgg]] and identifies itself with the
 * pretty name `stddev_pop`.
 *
 * @param child the expression whose values are aggregated
 */
case class StddevPop(child: Expression) extends StddevAgg(child) {
  override def prettyName: String = "stddev_pop"

  override def isSample: Boolean = false
}

/**
 * Sample standard deviation of a column.
 *
 * Reports `isSample = true` to [[StddevAgg]] and identifies itself with the
 * pretty name `stddev_samp`.
 *
 * @param child the expression whose values are aggregated
 */
case class StddevSamp(child: Expression) extends StddevAgg(child) {
  override def prettyName: String = "stddev_samp"

  override def isSample: Boolean = true
}

// Shared implementation of the standard-deviation aggregates.
//
// The class declares a five-slot DoubleType aggregation buffer (previous/current count,
// previous/current average, and Mk = running sum of squared differences from the mean) and
// describes initialisation, per-row update, buffer merge and final evaluation as expressions
// over those slots. Subclasses pick the sample or population formula through `isSample`.
//
// Compute standard deviation based on online algorithm specified here:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {

override def children: Seq[Expression] = child :: Nil

// The result may be null: evaluateExpression yields a null literal when the count is 0.
override def nullable: Boolean = true

// true  => final variance is Mk / (count - 1)  (sample)
// false => final variance is Mk / count        (population)
def isSample: Boolean

// Return data type.
override def dataType: DataType = resultType

// Expected input data type.
// TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
// new version at planning time (after analysis phase). For now, NullType is added at here
// to make it resolved when we have cases like `select stddev(null)`.
// We can use our analyzer to cast NullType to the default data type of the NumericType once
// we remove the old aggregate functions. Then, we will not need NullType at here.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))

// Type used both for the result and for every buffer slot (including the counts).
private val resultType = DoubleType

// Buffer slots. The "pre" slots hold the count / average as they were before the most
// recent update or merge; the "current" slots hold the values after it.
private val preCount = AttributeReference("preCount", resultType)()
private val currentCount = AttributeReference("currentCount", resultType)()
private val preAvg = AttributeReference("preAvg", resultType)()
private val currentAvg = AttributeReference("currentAvg", resultType)()
private val currentMk = AttributeReference("currentMk", resultType)()

// Slot order here must line up with the order of initialValues, updateExpressions and
// mergeExpressions below.
override val bufferAttributes = preCount :: currentCount :: preAvg ::
currentAvg :: currentMk :: Nil

// Every slot starts at 0 (as a double).
override val initialValues = Seq(
/* preCount = */ Cast(Literal(0), resultType),
/* currentCount = */ Cast(Literal(0), resultType),
/* preAvg = */ Cast(Literal(0), resultType),
/* currentAvg = */ Cast(Literal(0), resultType),
/* currentMk = */ Cast(Literal(0), resultType)
)

// Per-row update. A null input leaves every slot unchanged.
// NOTE(review): avgAdd divides by currentCount and mkAdd treats currentAvg as the *updated*
// average (see the formula comments). That only matches the formulas if the slots are written
// in the order listed, with later expressions seeing the values written by earlier ones.
// Confirm against how AlgebraicAggregate evaluates updateExpressions before reordering.
override val updateExpressions = {

// update average
// avg = avg + (value - avg)/count
def avgAdd: Expression = {
currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount)
}

// update sum of square of difference from mean
// Mk = Mk + (value - preAvg) * (value - updatedAvg)
def mkAdd: Expression = {
val delta1 = Cast(child, resultType) - preAvg
val delta2 = Cast(child, resultType) - currentAvg
currentMk + (delta1 * delta2)
}

Seq(
/* preCount = */ If(IsNull(child), preCount, currentCount),
/* currentCount = */ If(IsNull(child), currentCount,
Add(currentCount, Cast(Literal(1), resultType))),
/* preAvg = */ If(IsNull(child), preAvg, currentAvg),
/* currentAvg = */ If(IsNull(child), currentAvg, avgAdd),
/* currentMk = */ If(IsNull(child), currentMk, mkAdd)
)
}

// Merge of two buffers. `.left` / `.right` presumably select the two buffers being merged
// (the accessors are not defined in this class; verify in AlgebraicAggregate).
// NOTE(review): avgMerge and mkMerge read preCount / preAvg without `.left` / `.right`.
// This looks like it relies on those slots already holding the left-side count / average
// assigned by the first and third expressions of the Seq below. Confirm the evaluation
// order before changing it.
override val mergeExpressions = {

// count merge
def countMerge: Expression = {
currentCount.left + currentCount.right
}

// average merge
// count-weighted mean of the two averages
def avgMerge: Expression = {
((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) /
(preCount + currentCount.right)
}

// update sum of square differences
// Mk = Mk.left + Mk.right + avgDelta^2 * (nLeft * nRight) / (nLeft + nRight)
def mkMerge: Expression = {
val avgDelta = currentAvg.right - preAvg
val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) /
(preCount + currentCount.right)

currentMk.left + currentMk.right + mkDelta
}

// For each merged slot, a null on one side falls back to the other side's value;
// the "pre" slots fall back to 0 when the left value is null.
Seq(
/* preCount = */ If(IsNull(currentCount.left),
Cast(Literal(0), resultType), currentCount.left),
/* currentCount = */ If(IsNull(currentCount.left), currentCount.right,
If(IsNull(currentCount.right), currentCount.left, countMerge)),
/* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left),
/* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right,
If(IsNull(currentAvg.right), currentAvg.left, avgMerge)),
/* currentMk = */ If(IsNull(currentMk.left), currentMk.right,
If(IsNull(currentMk.right), currentMk.left, mkMerge))
)
}

// Final result computed from the buffer.
override val evaluateExpression = {
// when currentCount == 0, return null
// when currentCount == 1, return 0
// when currentCount > 1
// stddev_samp = sqrt (currentMk/(currentCount -1))
// stddev_pop = sqrt (currentMk/currentCount)
// isSample is read once here, while the expression is built, not per row.
val varCol = {
if (isSample) {
currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType)
}
else {
currentMk / currentCount
}
}

If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
Cast(Sqrt(varCol), resultType)))
}
}

case class Sum(child: Expression) extends AlgebraicAggregate {

override def children: Seq[Expression] = child :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,24 @@ object Utils {
mode = aggregate.Complete,
isDistinct = false)

case expressions.Stddev(child) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.Stddev(child),
mode = aggregate.Complete,
isDistinct = false)

case expressions.StddevPop(child) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.StddevPop(child),
mode = aggregate.Complete,
isDistinct = false)

case expressions.StddevSamp(child) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.StddevSamp(child),
mode = aggregate.Complete,
isDistinct = false)

case expressions.Sum(child) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.Sum(child),
Expand Down
Loading

0 comments on commit f4a2280

Please sign in to comment.