Skip to content

Commit

Permalink
[SPARK-8176] [SPARK-8197] [SQL] function to_date/ trunc
Browse files Browse the repository at this point in the history
This PR is based on apache#6988 , thanks to adrian-wang .

This brings two SQL functions: to_date() and trunc().

Closes apache#6988

Author: Daoyuan Wang <[email protected]>
Author: Davies Liu <[email protected]>

Closes apache#7805 from davies/to_date and squashes the following commits:

2c7beba [Davies Liu] Merge branch 'master' of github.com:apache/spark into to_date
310dd55 [Daoyuan Wang] remove dup test in rebase
980b092 [Daoyuan Wang] resolve rebase conflict
a476c5a [Daoyuan Wang] address comments from davies
d44ea5f [Daoyuan Wang] function to_date, trunc
  • Loading branch information
adrian-wang authored and rxin committed Jul 31, 2015
1 parent 9307f56 commit 83670fc
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 2 deletions.
30 changes: 30 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,36 @@ def months_between(date1, date2):
return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))


@since(1.5)
def to_date(col):
"""
Converts the column of StringType or TimestampType into DateType.
>>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
>>> df.select(to_date(df.t).alias('date')).collect()
[Row(date=datetime.date(1997, 2, 28))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.to_date(_to_java_column(col)))


@since(1.5)
def trunc(date, format):
"""
Returns date truncated to the unit specified by the format.
:param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
>>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d'])
>>> df.select(trunc(df.d, 'year').alias('year')).collect()
[Row(year=datetime.date(1997, 1, 1))]
>>> df.select(trunc(df.d, 'mon').alias('month')).collect()
[Row(month=datetime.date(1997, 2, 1))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.trunc(_to_java_column(date), format))


@since(1.5)
def size(col):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ object FunctionRegistry {
expression[NextDay]("next_day"),
expression[Quarter]("quarter"),
expression[Second]("second"),
expression[ToDate]("to_date"),
expression[TruncDate]("trunc"),
expression[UnixTimestamp]("unix_timestamp"),
expression[WeekOfYear]("weekofyear"),
expression[Year]("year"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,6 @@ case class FromUnixTime(sec: Expression, format: Expression)
})
}
}

}

/**
Expand Down Expand Up @@ -696,3 +695,90 @@ case class MonthsBetween(date1: Expression, date2: Expression)
})
}
}

/**
* Returns the date part of a timestamp or string.
*/
case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

// Implicit casting of spark will accept string in both date and timestamp format, as
// well as TimestampType.
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)

override def dataType: DataType = DateType

override def eval(input: InternalRow): Any = child.eval(input)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, d => d)
}
}

/*
* Returns date truncated to the unit specified by the format.
*/
case class TruncDate(date: Expression, format: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = date
override def right: Expression = format

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
override def prettyName: String = "trunc"

lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])

override def eval(input: InternalRow): Any = {
val minItem = if (format.foldable) {
minItemConst
} else {
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
}
if (minItem == -1) {
// unknown format
null
} else {
val d = date.eval(input)
if (d == null) {
null
} else {
DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem)
}
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")

if (format.foldable) {
if (minItemConst == -1) {
s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
"""
} else {
val d = date.gen(ctx)
s"""
${d.code}
boolean ${ev.isNull} = ${d.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst);
}
"""
}
} else {
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
val form = ctx.freshName("form")
s"""
int $form = $dtu.parseTruncLevel($fmt);
if ($form == -1) {
${ev.isNull} = true;
} else {
${ev.primitive} = $dtu.truncDate($dateVal, $form);
}
"""
})
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -779,4 +779,38 @@ object DateTimeUtils {
}
date + (lastDayOfMonthInYear - dayInYear)
}

private val TRUNC_TO_YEAR = 1
private val TRUNC_TO_MONTH = 2
private val TRUNC_INVALID = -1

/**
* Returns the trunc date from original date and trunc level.
* Trunc level should be generated using `parseTruncLevel()`, should only be 1 or 2.
*/
def truncDate(d: Int, level: Int): Int = {
if (level == TRUNC_TO_YEAR) {
d - DateTimeUtils.getDayInYear(d) + 1
} else if (level == TRUNC_TO_MONTH) {
d - DateTimeUtils.getDayOfMonth(d) + 1
} else {
throw new Exception(s"Invalid trunc level: $level")
}
}

/**
* Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, or TRUNC_INVALID,
* TRUNC_INVALID means unsupported truncate level.
*/
def parseTruncLevel(format: UTF8String): Int = {
if (format == null) {
TRUNC_INVALID
} else {
format.toString.toUpperCase match {
case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
case _ => TRUNC_INVALID
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,34 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
}

test("function to_date") {
checkEvaluation(
ToDate(Literal(Date.valueOf("2015-07-22"))),
DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22")))
checkEvaluation(ToDate(Literal.create(null, DateType)), null)
}

test("function trunc") {
def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)),
expected)
checkEvaluation(
TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
expected)
}
val date = Date.valueOf("2015-07-22")
Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach{ fmt =>
testTrunc(date, fmt, Date.valueOf("2015-01-01"))
}
Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
testTrunc(date, fmt, Date.valueOf("2015-07-01"))
}
testTrunc(date, "DD", null)
testTrunc(date, null, null)
testTrunc(null, "MON", null)
testTrunc(null, null, null)
}

test("from_unixtime") {
val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
Expand Down Expand Up @@ -405,5 +433,4 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,8 @@ object NonFoldableLiteral {
val lit = Literal(value)
NonFoldableLiteral(lit.value, lit.dataType)
}
def create(value: Any, dataType: DataType): NonFoldableLiteral = {
val lit = Literal.create(value, dataType)
NonFoldableLiteral(lit.value, lit.dataType)
}
}
16 changes: 16 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2181,6 +2181,22 @@ object functions {
*/
def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p))

/*
* Converts the column into DateType.
*
* @group datetime_funcs
* @since 1.5.0
*/
def to_date(e: Column): Column = ToDate(e.expr)

/**
* Returns date truncated to the unit specified by the format.
*
* @group datetime_funcs
* @since 1.5.0
*/
def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format))

//////////////////////////////////////////////////////////////////////////////////////////////
// Collection functions
//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,50 @@ class DateFunctionsSuite extends QueryTest {
Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30"))))
}

test("function to_date") {
val d1 = Date.valueOf("2015-07-22")
val d2 = Date.valueOf("2015-07-01")
val t1 = Timestamp.valueOf("2015-07-22 10:00:00")
val t2 = Timestamp.valueOf("2014-12-31 23:59:59")
val s1 = "2015-07-22 10:00:00"
val s2 = "2014-12-31"
val df = Seq((d1, t1, s1), (d2, t2, s2)).toDF("d", "t", "s")

checkAnswer(
df.select(to_date(col("t"))),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
checkAnswer(
df.select(to_date(col("d"))),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01"))))
checkAnswer(
df.select(to_date(col("s"))),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))

checkAnswer(
df.selectExpr("to_date(t)"),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
checkAnswer(
df.selectExpr("to_date(d)"),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01"))))
checkAnswer(
df.selectExpr("to_date(s)"),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
}

test("function trunc") {
val df = Seq(
(1, Timestamp.valueOf("2015-07-22 10:00:00")),
(2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t")

checkAnswer(
df.select(trunc(col("t"), "YY")),
Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01"))))

checkAnswer(
df.selectExpr("trunc(t, 'Month')"),
Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01"))))
}

test("from_unixtime") {
val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
Expand Down

0 comments on commit 83670fc

Please sign in to comment.