Skip to content

Commit

Permalink
[SPARK-11510][SQL] Remove SQL aggregation tests for higher order statistics
Browse files Browse the repository at this point in the history

We have some aggregate function tests in both DataFrameAggregateSuite and SQLQuerySuite. The two have almost the same coverage and we should just remove the SQL one.

Author: Reynold Xin <[email protected]>

Closes apache#9475 from rxin/SPARK-11510.
  • Loading branch information
rxin authored and yhuai committed Nov 5, 2015
1 parent 411ff6a commit b6e0a5a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 147 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,8 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {

test("average") {
checkAnswer(
testData2.agg(avg('a)),
Row(2.0))

// Also check mean
checkAnswer(
testData2.agg(mean('a)),
Row(2.0))
testData2.agg(avg('a), mean('a)),
Row(2.0, 2.0))

checkAnswer(
testData2.agg(avg('a), sumDistinct('a)), // non-partial
Expand All @@ -98,6 +93,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
checkAnswer(
decimalData.agg(avg('a)),
Row(new java.math.BigDecimal(2.0)))

checkAnswer(
decimalData.agg(avg('a), sumDistinct('a)), // non-partial
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
Expand Down Expand Up @@ -168,44 +164,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {

test("zero count") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
assert(emptyTableData.count() === 0)

checkAnswer(
emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
Row(0, null))
}

test("stddev") {
val testData2ADev = math.sqrt(4/5.0)

val testData2ADev = math.sqrt(4 / 5.0)
checkAnswer(
testData2.agg(stddev('a)),
Row(testData2ADev))

checkAnswer(
testData2.agg(stddev_pop('a)),
Row(math.sqrt(4/6.0)))

checkAnswer(
testData2.agg(stddev_samp('a)),
Row(testData2ADev))
testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
}

test("zero stddev") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
assert(emptyTableData.count() == 0)

checkAnswer(
emptyTableData.agg(stddev('a)),
Row(null))

checkAnswer(
emptyTableData.agg(stddev_pop('a)),
Row(null))

checkAnswer(
emptyTableData.agg(stddev_samp('a)),
Row(null))
emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
Row(null, null, null))
}

test("zero sum") {
Expand All @@ -227,6 +202,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {

val sparkVariance = testData2.agg(variance('a))
checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol)

val sparkVariancePop = testData2.agg(var_pop('a))
checkAggregatesWithTol(sparkVariancePop, Row(4.0 / 6.0), absTol)

Expand All @@ -241,52 +217,35 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}

test("zero moments") {
val emptyTableData = Seq((1, 2)).toDF("a", "b")
assert(emptyTableData.count() === 1)

checkAnswer(
emptyTableData.agg(variance('a)),
Row(Double.NaN))

checkAnswer(
emptyTableData.agg(var_samp('a)),
Row(Double.NaN))

val input = Seq((1, 2)).toDF("a", "b")
checkAnswer(
emptyTableData.agg(var_pop('a)),
Row(0.0))
input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN))

checkAnswer(
emptyTableData.agg(skewness('a)),
Row(Double.NaN))

checkAnswer(
emptyTableData.agg(kurtosis('a)),
Row(Double.NaN))
input.agg(
expr("variance(a)"),
expr("var_samp(a)"),
expr("var_pop(a)"),
expr("skewness(a)"),
expr("kurtosis(a)")),
Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN))
}

test("null moments") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
assert(emptyTableData.count() === 0)

checkAnswer(
emptyTableData.agg(variance('a)),
Row(Double.NaN))

checkAnswer(
emptyTableData.agg(var_samp('a)),
Row(Double.NaN))

checkAnswer(
emptyTableData.agg(var_pop('a)),
Row(Double.NaN))

checkAnswer(
emptyTableData.agg(skewness('a)),
Row(Double.NaN))
emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))

checkAnswer(
emptyTableData.agg(kurtosis('a)),
Row(Double.NaN))
emptyTableData.agg(
expr("variance(a)"),
expr("var_samp(a)"),
expr("var_pop(a)"),
expr("skewness(a)"),
expr("kurtosis(a)")),
Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))
}
}
77 changes: 0 additions & 77 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -726,83 +726,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}

test("stddev") {
checkAnswer(
sql("SELECT STDDEV(a) FROM testData2"),
Row(math.sqrt(4.0 / 5.0))
)
}

test("stddev_pop") {
checkAnswer(
sql("SELECT STDDEV_POP(a) FROM testData2"),
Row(math.sqrt(4.0 / 6.0))
)
}

test("stddev_samp") {
checkAnswer(
sql("SELECT STDDEV_SAMP(a) FROM testData2"),
Row(math.sqrt(4/5.0))
)
}

test("var_samp") {
val absTol = 1e-8
val sparkAnswer = sql("SELECT VAR_SAMP(a) FROM testData2")
val expectedAnswer = Row(4.0 / 5.0)
checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
}

test("variance") {
val absTol = 1e-8
val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2")
val expectedAnswer = Row(0.8)
checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
}

test("var_pop") {
val absTol = 1e-8
val sparkAnswer = sql("SELECT VAR_POP(a) FROM testData2")
val expectedAnswer = Row(4.0 / 6.0)
checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
}

test("skewness") {
val absTol = 1e-8
val sparkAnswer = sql("SELECT skewness(a) FROM testData2")
val expectedAnswer = Row(0.0)
checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
}

test("kurtosis") {
val absTol = 1e-8
val sparkAnswer = sql("SELECT kurtosis(a) FROM testData2")
val expectedAnswer = Row(-1.5)
checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
}

test("stddev agg") {
checkAnswer(
sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
(1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0))))
}

test("variance agg") {
val absTol = 1e-8
checkAggregatesWithTol(
sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a"),
(1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)),
absTol)
}

test("skewness and kurtosis agg") {
val absTol = 1e-8
val sparkAnswer = sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a")
val expectedAnswer = (1 to 3).map(i => Row(i, 0.0, -2.0))
checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
}

test("inner join where, one match per row") {
checkAnswer(
sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.Decimal


class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Expand Down

0 comments on commit b6e0a5a

Please sign in to comment.