Skip to content

Commit

Permalink
[SPARK-28077][SQL] Support ANSI SQL OVERLAY function.
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

The `OVERLAY` function is a `ANSI` `SQL`.
For example:
```
SELECT OVERLAY('abcdef' PLACING '45' FROM 4);

SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5);

SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5 FOR 0);

SELECT OVERLAY('babosa' PLACING 'ubb' FROM 2 FOR 4);
```
The results of the above four `SQL` are:
```
abc45f
yabadaba
yabadabadoo
bubba
```

Note: If the input string is null, then the result is null too.

There are some mainstream database support the syntax.
**PostgreSQL:**
https://www.postgresql.org/docs/11/functions-string.html

**Vertica:** https://www.vertica.com/docs/9.2.x/HTML/Content/Authoring/SQLReferenceManual/Functions/String/OVERLAY.htm?zoom_highlight=overlay

**Oracle:**
https://docs.oracle.com/en/database/oracle/oracle-database/19/arpls/UTL_RAW.html#GUID-342E37E7-FE43-4CE1-A0E9-7DAABD000369

**DB2:**
https://www.ibm.com/support/knowledgecenter/SSGMCP_5.3.0/com.ibm.cics.rexx.doc/rexx/overlay.html

There are some show of the PR on my production environment.
```
spark-sql> SELECT OVERLAY('abcdef' PLACING '45' FROM 4);
abc45f
Time taken: 6.385 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5);
yabadaba
Time taken: 0.191 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5 FOR 0);
yabadabadoo
Time taken: 0.186 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY('babosa' PLACING 'ubb' FROM 2 FOR 4);
bubba
Time taken: 0.151 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING '45' FROM 4);
NULL
Time taken: 0.22 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING 'daba' FROM 5);
NULL
Time taken: 0.157 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING 'daba' FROM 5 FOR 0);
NULL
Time taken: 0.254 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING 'ubb' FROM 2 FOR 4);
NULL
Time taken: 0.159 seconds, Fetched 1 row(s)
```

## How was this patch tested?

Exists UT and new UT.

Closes apache#24918 from beliefer/ansi-sql-overlay.

Lead-authored-by: gengjiaan <[email protected]>
Co-authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
  • Loading branch information
2 people authored and ueshin committed Jun 28, 2019
1 parent 31e7c37 commit 832ff87
Show file tree
Hide file tree
Showing 10 changed files with 281 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/sql-keywords.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,14 @@ Below is a list of all the keywords in Spark SQL.
<tr><td>OUTPUTFORMAT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>OVER</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>OVERLAPS</td><td>reserved</td><td>non-reserved</td><td>reserved</td></tr>
<tr><td>OVERLAY</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>OVERWRITE</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PARTITION</td><td>non-reserved</td><td>non-reserved</td><td>reserved</td></tr>
<tr><td>PARTITIONED</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PARTITIONS</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PERCENT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PIVOT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PLACING</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>POSITION</td><td>non-reserved</td><td>non-reserved</td><td>reserved</td></tr>
<tr><td>PRECEDING</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PRIMARY</td><td>reserved</td><td>non-reserved</td><td>reserved</td></tr>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,8 @@ primaryExpression
((FOR | ',') len=valueExpression)? ')' #substring
| TRIM '(' trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)?
FROM srcStr=valueExpression ')' #trim
| OVERLAY '(' input=valueExpression PLACING replace=valueExpression
FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay
;

constant
Expand Down Expand Up @@ -1002,6 +1004,7 @@ ansiNonReserved
| OUT
| OUTPUTFORMAT
| OVER
| OVERLAY
| OVERWRITE
| PARTITION
| PARTITIONED
Expand Down Expand Up @@ -1253,12 +1256,14 @@ nonReserved
| OUTPUTFORMAT
| OVER
| OVERLAPS
| OVERLAY
| OVERWRITE
| PARTITION
| PARTITIONED
| PARTITIONS
| PERCENTLIT
| PIVOT
| PLACING
| POSITION
| PRECEDING
| PRIMARY
Expand Down Expand Up @@ -1509,12 +1514,14 @@ OUTER: 'OUTER';
OUTPUTFORMAT: 'OUTPUTFORMAT';
OVER: 'OVER';
OVERLAPS: 'OVERLAPS';
OVERLAY: 'OVERLAY';
OVERWRITE: 'OVERWRITE';
PARTITION: 'PARTITION';
PARTITIONED: 'PARTITIONED';
PARTITIONS: 'PARTITIONS';
PERCENTLIT: 'PERCENT';
PIVOT: 'PIVOT';
PLACING: 'PLACING';
POSITION: 'POSITION';
PRECEDING: 'PRECEDING';
PRIMARY: 'PRIMARY';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ object FunctionRegistry {
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReplace]("replace"),
expression[Overlay]("overlay"),
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import org.apache.spark.sql.types._
* - [[UnaryExpression]]: an expression that has one child.
* - [[BinaryExpression]]: an expression that has two children.
* - [[TernaryExpression]]: an expression that has three children.
* - [[QuaternaryExpression]]: an expression that has four children.
* - [[BinaryOperator]]: a special case of [[BinaryExpression]] that requires two children to have
* the same output data type.
*
Expand Down Expand Up @@ -757,6 +758,111 @@ abstract class TernaryExpression extends Expression {
}
}

/**
* An expression with four inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
*/
abstract class QuaternaryExpression extends Expression {

override def foldable: Boolean = children.forall(_.foldable)

override def nullable: Boolean = children.exists(_.nullable)

/**
* Default behavior of evaluation according to the default nullability of QuaternaryExpression.
* If subclass of QuaternaryExpression override nullable, probably should also override this.
*/
override def eval(input: InternalRow): Any = {
val exprs = children
val value1 = exprs(0).eval(input)
if (value1 != null) {
val value2 = exprs(1).eval(input)
if (value2 != null) {
val value3 = exprs(2).eval(input)
if (value3 != null) {
val value4 = exprs(3).eval(input)
if (value4 != null) {
return nullSafeEval(value1, value2, value3, value4)
}
}
}
}
null
}

/**
* Called by default [[eval]] implementation. If subclass of QuaternaryExpression keep the
* default nullability, they can override this method to save null-check code. If we need
* full control of evaluation process, we should override [[eval]].
*/
protected def nullSafeEval(input1: Any, input2: Any, input3: Any, input4: Any): Any =
sys.error(s"QuaternaryExpressions must override either eval or nullSafeEval")

/**
* Short hand for generating quaternary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f accepts four variable names and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String, String) => String): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3, eval4) => {
s"${ev.value} = ${f(eval1, eval2, eval3, eval4)};"
})
}

/**
* Short hand for generating quaternary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f function that accepts the 4 non-null evaluation result names of children
* and returns Java code to compute the output.
*/
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String, String) => String): ExprCode = {
val firstGen = children(0).genCode(ctx)
val secondGen = children(1).genCode(ctx)
val thridGen = children(2).genCode(ctx)
val fourthGen = children(3).genCode(ctx)
val resultCode = f(firstGen.value, secondGen.value, thridGen.value, fourthGen.value)

if (nullable) {
val nullSafeEval =
firstGen.code + ctx.nullSafeExec(children(0).nullable, firstGen.isNull) {
secondGen.code + ctx.nullSafeExec(children(1).nullable, secondGen.isNull) {
thridGen.code + ctx.nullSafeExec(children(2).nullable, thridGen.isNull) {
fourthGen.code + ctx.nullSafeExec(children(3).nullable, fourthGen.isNull) {
s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
}
}
}
}

ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
ev.copy(code = code"""
${firstGen.code}
${secondGen.code}
${thridGen.code}
${fourthGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
}
}
}

/**
* A trait used for resolving nullable flags, including `nullable`, `containsNull` of [[ArrayType]]
* and `valueContainsNull` of [[MapType]], containsNull, valueContainsNull flags of the output date
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -454,6 +455,69 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp
override def prettyName: String = "replace"
}

object Overlay {

def calculate(input: UTF8String, replace: UTF8String, pos: Int, len: Int): UTF8String = {
val builder = new UTF8StringBuilder
builder.append(input.substringSQL(1, pos - 1))
builder.append(replace)
// If you specify length, it must be a positive whole number or zero.
// Otherwise it will be ignored.
// The default value for length is the length of replace.
val length = if (len >= 0) {
len
} else {
replace.numChars
}
builder.append(input.substringSQL(pos + length, Int.MaxValue))
builder.build()
}
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(input, replace, pos[, len]) - Replace `input` with `replace` that starts at `pos` and is of length `len`.",
examples = """
Examples:
> SELECT _FUNC_('Spark SQL' PLACING '_' FROM 6);
Spark_SQL
> SELECT _FUNC_('Spark SQL' PLACING 'CORE' FROM 7);
Spark CORE
> SELECT _FUNC_('Spark SQL' PLACING 'ANSI ' FROM 7 FOR 0);
Spark ANSI SQL
> SELECT _FUNC_('Spark SQL' PLACING 'tructured' FROM 2 FOR 4);
Structured SQL
""")
// scalastyle:on line.size.limit
case class Overlay(input: Expression, replace: Expression, pos: Expression, len: Expression)
extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {

def this(str: Expression, replace: Expression, pos: Expression) = {
this(str, replace, pos, Literal.create(-1, IntegerType))
}

override def dataType: DataType = StringType

override def inputTypes: Seq[AbstractDataType] =
Seq(StringType, StringType, IntegerType, IntegerType)

override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil

override def nullSafeEval(inputEval: Any, replaceEval: Any, posEval: Any, lenEval: Any): Any = {
val inputStr = inputEval.asInstanceOf[UTF8String]
val replaceStr = replaceEval.asInstanceOf[UTF8String]
val position = posEval.asInstanceOf[Int]
val length = lenEval.asInstanceOf[Int]
Overlay.calculate(inputStr, replaceStr, position, length)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (input, replace, pos, len) =>
"org.apache.spark.sql.catalyst.expressions.Overlay" +
s".calculate($input, $replace, $pos, $len);")
}
}

object StringTranslate {

def buildDict(matchingString: UTF8String, replaceString: UTF8String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}
}

/**
* Create a Overlay expression.
*/
override def visitOverlay(ctx: OverlayContext): Expression = withOrigin(ctx) {
val input = expression(ctx.input)
val replace = expression(ctx.replace)
val position = expression(ctx.position)
val lengthOpt = Option(ctx.length).map(expression)
lengthOpt match {
case Some(length) => Overlay(input, replace, position, length)
case None => new Overlay(input, replace, position)
}
}

/**
* Create a (windowed) Function expression.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}

test("overlay") {
checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("_"),
Literal.create(6, IntegerType)), "Spark_SQL")
checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("CORE"),
Literal.create(7, IntegerType)), "Spark CORE")
checkEvaluation(Overlay(Literal("Spark SQL"), Literal("ANSI "),
Literal.create(7, IntegerType), Literal.create(0, IntegerType)), "Spark ANSI SQL")
checkEvaluation(Overlay(Literal("Spark SQL"), Literal("tructured"),
Literal.create(2, IntegerType), Literal.create(4, IntegerType)), "Structured SQL")
checkEvaluation(new Overlay(Literal.create(null, StringType), Literal("_"),
Literal.create(6, IntegerType)), null)
checkEvaluation(new Overlay(Literal.create(null, StringType), Literal("CORE"),
Literal.create(7, IntegerType)), null)
checkEvaluation(Overlay(Literal.create(null, StringType), Literal("ANSI "),
Literal.create(7, IntegerType), Literal.create(0, IntegerType)), null)
checkEvaluation(Overlay(Literal.create(null, StringType), Literal("tructured"),
Literal.create(2, IntegerType), Literal.create(4, IntegerType)), null)
// scalastyle:off
// non ascii characters are not allowed in the source code, so we disable the scalastyle.
checkEvaluation(new Overlay(Literal("Spark的SQL"), Literal("_"),
Literal.create(6, IntegerType)), "Spark_SQL")
// scalastyle:on
}

test("translate") {
checkEvaluation(
StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,35 @@ class PlanParserSuite extends AnalysisTest {
)
}

test("OVERLAY function") {
def assertOverlayPlans(inputSQL: String, expectedExpression: Expression): Unit = {
comparePlans(
parsePlan(inputSQL),
Project(Seq(UnresolvedAlias(expectedExpression)), OneRowRelation())
)
}

assertOverlayPlans(
"SELECT OVERLAY('Spark SQL' PLACING '_' FROM 6)",
new Overlay(Literal("Spark SQL"), Literal("_"), Literal(6))
)

assertOverlayPlans(
"SELECT OVERLAY('Spark SQL' PLACING 'CORE' FROM 7)",
new Overlay(Literal("Spark SQL"), Literal("CORE"), Literal(7))
)

assertOverlayPlans(
"SELECT OVERLAY('Spark SQL' PLACING 'ANSI ' FROM 7 FOR 0)",
Overlay(Literal("Spark SQL"), Literal("ANSI "), Literal(7), Literal(0))
)

assertOverlayPlans(
"SELECT OVERLAY('Spark SQL' PLACING 'tructured' FROM 2 FOR 4)",
Overlay(Literal("Spark SQL"), Literal("tructured"), Literal(2), Literal(4))
)
}

test("precedence of set operations") {
val a = table("a").select(star())
val b = table("b").select(star())
Expand Down
22 changes: 22 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2516,6 +2516,28 @@ object functions {
SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
}

/**
* Overlay the specified portion of `src` with `replaceString`,
* starting from byte position `pos` of `inputString` and proceeding for `len` bytes.
*
* @group string_funcs
* @since 3.0.0
*/
def overlay(src: Column, replaceString: String, pos: Int, len: Int): Column = withExpr {
Overlay(src.expr, lit(replaceString).expr, lit(pos).expr, lit(len).expr)
}

/**
* Overlay the specified portion of `src` with `replaceString`,
* starting from byte position `pos` of `inputString`.
*
* @group string_funcs
* @since 3.0.0
*/
def overlay(src: Column, replaceString: String, pos: Int): Column = withExpr {
new Overlay(src.expr, lit(replaceString).expr, lit(pos).expr)
}

/**
* Translate any character in the src by a character in replaceString.
* The characters in replaceString correspond to the characters in matchingString.
Expand Down
Loading

0 comments on commit 832ff87

Please sign in to comment.