Skip to content

Commit

Permalink
[SPARK-2588][SQL] Add some more DSLs.
Browse files Browse the repository at this point in the history
Author: Takuya UESHIN <[email protected]>

Closes apache#1491 from ueshin/issues/SPARK-2588 and squashes the following commits:

43d0a46 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-2588
1023ea0 [Takuya UESHIN] Modify tests to use DSLs.
2310bf1 [Takuya UESHIN] Add some more DSLs.
  • Loading branch information
ueshin authored and marmbrus committed Jul 23, 2014
1 parent f776bc9 commit 1b790cf
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,24 @@ package object dsl {
def === (other: Expression) = EqualTo(expr, other)
def !== (other: Expression) = Not(EqualTo(expr, other))

def in(list: Expression*) = In(expr, list)

def like(other: Expression) = Like(expr, other)
def rlike(other: Expression) = RLike(expr, other)
def contains(other: Expression) = Contains(expr, other)
def startsWith(other: Expression) = StartsWith(expr, other)
def endsWith(other: Expression) = EndsWith(expr, other)
def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
Substring(expr, pos, len)
def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
Substring(expr, pos, len)

def isNull = IsNull(expr)
def isNotNull = IsNotNull(expr)

def getItem(ordinal: Expression) = GetItem(expr, ordinal)
def getField(fieldName: String) = GetField(expr, fieldName)

def cast(to: DataType) = Cast(expr, to)

def asc = SortOrder(expr, Ascending)
Expand Down Expand Up @@ -112,6 +128,7 @@ package object dsl {
def sumDistinct(e: Expression) = SumDistinct(e)
def count(e: Expression) = Count(e)
def countDistinct(e: Expression*) = CountDistinct(e)
def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd)
def avg(e: Expression) = Average(e)
def first(e: Expression) = First(e)
def min(e: Expression) = Min(e)
Expand Down Expand Up @@ -163,6 +180,18 @@ package object dsl {

/** Creates a new AttributeReference of type binary */
def binary = AttributeReference(s, BinaryType, nullable = true)()

/** Creates a new AttributeReference of type array */
def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)()

/** Creates a new AttributeReference of type map */
def map(keyType: DataType, valueType: DataType): AttributeReference =
map(MapType(keyType, valueType))
def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)()

/** Creates a new AttributeReference of type struct */
def struct(fields: StructField*): AttributeReference = struct(StructType(fields))
def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)()
}

implicit class DslAttribute(a: AttributeReference) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,17 +301,17 @@ class ExpressionEvaluationSuite extends FunSuite {
val c3 = 'a.boolean.at(2)
val c4 = 'a.boolean.at(3)

checkEvaluation(IsNull(c1), false, row)
checkEvaluation(IsNotNull(c1), true, row)
checkEvaluation(c1.isNull, false, row)
checkEvaluation(c1.isNotNull, true, row)

checkEvaluation(IsNull(c2), true, row)
checkEvaluation(IsNotNull(c2), false, row)
checkEvaluation(c2.isNull, true, row)
checkEvaluation(c2.isNotNull, false, row)

checkEvaluation(IsNull(Literal(1, ShortType)), false)
checkEvaluation(IsNotNull(Literal(1, ShortType)), true)
checkEvaluation(Literal(1, ShortType).isNull, false)
checkEvaluation(Literal(1, ShortType).isNotNull, true)

checkEvaluation(IsNull(Literal(null, ShortType)), true)
checkEvaluation(IsNotNull(Literal(null, ShortType)), false)
checkEvaluation(Literal(null, ShortType).isNull, true)
checkEvaluation(Literal(null, ShortType).isNotNull, false)

checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
Expand All @@ -326,11 +326,11 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(If(Literal(false, BooleanType),
Literal("a", StringType), Literal("b", StringType)), "b", row)

checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
checkEvaluation(c1 in (c1, c2), true, row)
checkEvaluation(
Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row)
checkEvaluation(
Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row)
}

test("case when") {
Expand Down Expand Up @@ -420,6 +420,10 @@ class ExpressionEvaluationSuite extends FunSuite {

assert(GetField(Literal(null, typeS), "a").nullable === true)
assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true)

checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row)
}

test("arithmetic") {
Expand Down Expand Up @@ -472,20 +476,20 @@ class ExpressionEvaluationSuite extends FunSuite {
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)

checkEvaluation(Contains(c1, "b"), true, row)
checkEvaluation(Contains(c1, "x"), false, row)
checkEvaluation(Contains(c2, "b"), null, row)
checkEvaluation(Contains(c1, Literal(null, StringType)), null, row)
checkEvaluation(c1 contains "b", true, row)
checkEvaluation(c1 contains "x", false, row)
checkEvaluation(c2 contains "b", null, row)
checkEvaluation(c1 contains Literal(null, StringType), null, row)

checkEvaluation(StartsWith(c1, "a"), true, row)
checkEvaluation(StartsWith(c1, "b"), false, row)
checkEvaluation(StartsWith(c2, "a"), null, row)
checkEvaluation(StartsWith(c1, Literal(null, StringType)), null, row)
checkEvaluation(c1 startsWith "a", true, row)
checkEvaluation(c1 startsWith "b", false, row)
checkEvaluation(c2 startsWith "a", null, row)
checkEvaluation(c1 startsWith Literal(null, StringType), null, row)

checkEvaluation(EndsWith(c1, "c"), true, row)
checkEvaluation(EndsWith(c1, "b"), false, row)
checkEvaluation(EndsWith(c2, "b"), null, row)
checkEvaluation(EndsWith(c1, Literal(null, StringType)), null, row)
checkEvaluation(c1 endsWith "c", true, row)
checkEvaluation(c1 endsWith "b", false, row)
checkEvaluation(c2 endsWith "b", null, row)
checkEvaluation(c1 endsWith Literal(null, StringType), null, row)
}

test("Substring") {
Expand Down Expand Up @@ -542,5 +546,10 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false)
assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true)
assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true)

checkEvaluation(s.substr(0, 2), "ex", row)
checkEvaluation(s.substr(0), "example", row)
checkEvaluation(s.substring(0, 2), "ex", row)
checkEvaluation(s.substring(0), "example", row)
}
}
15 changes: 7 additions & 8 deletions sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test._

/* Implicits */
Expand All @@ -41,15 +40,15 @@ class DslQuerySuite extends QueryTest {

test("agg") {
checkAnswer(
testData2.groupBy('a)('a, Sum('b)),
testData2.groupBy('a)('a, sum('b)),
Seq((1,3),(2,3),(3,3))
)
checkAnswer(
testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)),
testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)),
9
)
checkAnswer(
testData2.aggregate(Sum('b)),
testData2.aggregate(sum('b)),
9
)
}
Expand Down Expand Up @@ -104,19 +103,19 @@ class DslQuerySuite extends QueryTest {
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))

checkAnswer(
arrayData.orderBy(GetItem('data, 0).asc),
arrayData.orderBy('data.getItem(0).asc),
arrayData.collect().sortBy(_.data(0)).toSeq)

checkAnswer(
arrayData.orderBy(GetItem('data, 0).desc),
arrayData.orderBy('data.getItem(0).desc),
arrayData.collect().sortBy(_.data(0)).reverse.toSeq)

checkAnswer(
mapData.orderBy(GetItem('data, 1).asc),
mapData.orderBy('data.getItem(1).asc),
mapData.collect().sortBy(_.data(1)).toSeq)

checkAnswer(
mapData.orderBy(GetItem('data, 1).desc),
mapData.orderBy('data.getItem(1).desc),
mapData.collect().sortBy(_.data(1)).reverse.toSeq)
}

Expand Down

0 comments on commit 1b790cf

Please sign in to comment.