[SPARK-28989][SQL] Add a SQLConf spark.sql.ansi.enabled
### What changes were proposed in this pull request?
Currently, there are three separate configurations for compatibility with ANSI SQL:

* `spark.sql.parser.ansi.enabled`
* `spark.sql.decimalOperations.nullOnOverflow`
* `spark.sql.failOnIntegralTypeOverflow`

This PR adds a new configuration, `spark.sql.ansi.enabled`, and removes the three options above. When the configuration is true, Spark tries to conform to the ANSI SQL specification. It is disabled by default. A brief usage sketch follows.
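As a rough usage sketch (not part of this patch; it assumes a running `SparkSession` named `spark`), the single flag can be toggled either through the conf API or through plain SQL:

```scala
// Hypothetical illustration only: toggling the new flag at runtime.
spark.conf.set("spark.sql.ansi.enabled", "true")  // ANSI mode: overflow raises errors, reserved keywords are enforced
spark.sql("SET spark.sql.ansi.enabled=false")     // back to the default, legacy behaviour
```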

### Why are the changes needed?

This keeps the ANSI-compliance configuration simple and straightforward: a single flag instead of three separate ones.

### Does this PR introduce any user-facing change?

Yes. The ANSI-compatibility features are now controlled by the single configuration `spark.sql.ansi.enabled` instead of the three configurations listed above.

### How was this patch tested?

Existing unit tests.

Closes apache#25693 from gengliangwang/ansiEnabled.

Lead-authored-by: Gengliang Wang <[email protected]>
Co-authored-by: Xiao Li <[email protected]>
Signed-off-by: Xiao Li <[email protected]>
gengliangwang and gatorsmile committed Sep 19, 2019
1 parent a6a663c commit b917a65
Showing 29 changed files with 86 additions and 109 deletions.
8 changes: 4 additions & 4 deletions docs/sql-keywords.md
@@ -19,15 +19,15 @@ license: |
limitations under the License.
---

-When `spark.sql.parser.ansi.enabled` is true, Spark SQL has two kinds of keywords:
+When `spark.sql.ansi.enabled` is true, Spark SQL has two kinds of keywords:
* Reserved keywords: Keywords that are reserved and can't be used as identifiers for table, view, column, function, alias, etc.
* Non-reserved keywords: Keywords that have a special meaning only in particular contexts and can be used as identifiers in other contexts. For example, `SELECT 1 WEEK` is an interval literal, but WEEK can be used as identifiers in other places.

-When `spark.sql.parser.ansi.enabled` is false, Spark SQL has two kinds of keywords:
-* Non-reserved keywords: Same definition as the one when `spark.sql.parser.ansi.enabled=true`.
+When `spark.sql.ansi.enabled` is false, Spark SQL has two kinds of keywords:
+* Non-reserved keywords: Same definition as the one when `spark.sql.ansi.enabled=true`.
* Strict-non-reserved keywords: A strict version of non-reserved keywords, which can not be used as table alias.

-By default `spark.sql.parser.ansi.enabled` is false.
+By default `spark.sql.ansi.enabled` is false.

Below is a list of all the keywords in Spark SQL.

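The keyword behaviour described in the `docs/sql-keywords.md` hunk above can be exercised as follows. This is a hedged sketch, not part of the diff; it assumes `SELECT` is on the ANSI reserved-keyword list and that `spark` is an existing `SparkSession`:

```scala
// Hypothetical illustration of reserved vs. non-reserved keywords.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT 1 AS select").show()  // expected to parse: `select` acts as an ordinary alias in the default mode

spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT 1 AS select").show()  // expected to fail with a ParseException: `select` is reserved under ANSI mode
```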
@@ -344,7 +344,7 @@ object CatalystTypeConverters {
private class DecimalConverter(dataType: DecimalType)
extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {

-private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow
+private val nullOnOverflow = !SQLConf.get.ansiEnabled

override def toCatalystImpl(scalaValue: Any): Decimal = {
val decimal = scalaValue match {
@@ -26,7 +26,7 @@ import org.apache.spark.unsafe.types.UTF8String

object SerializerBuildHelper {

-private def nullOnOverflow: Boolean = SQLConf.get.decimalOperationsNullOnOverflow
+private def nullOnOverflow: Boolean = !SQLConf.get.ansiEnabled

def createSerializerForBoolean(inputObject: Expression): Expression = {
Invoke(inputObject, "booleanValue", BooleanType)
@@ -82,7 +82,7 @@ object DecimalPrecision extends TypeCoercionRule {
PromotePrecision(Cast(e, dataType))
}

-private def nullOnOverflow: Boolean = SQLConf.get.decimalOperationsNullOnOverflow
+private def nullOnOverflow: Boolean = !SQLConf.get.ansiEnabled

override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// fix decimal precision for expressions
@@ -114,7 +114,7 @@ object RowEncoder {
d,
"fromDecimal",
inputObject :: Nil,
-returnNullable = false), d, SQLConf.get.decimalOperationsNullOnOverflow)
+returnNullable = false), d, !SQLConf.get.ansiEnabled)

case StringType => createSerializerForString(inputObject)

@@ -289,7 +289,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String

private lazy val dateFormatter = DateFormatter()
private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId)
-private val failOnIntegralTypeOverflow = SQLConf.get.failOnIntegralTypeOverflow
+private val failOnIntegralTypeOverflow = SQLConf.get.ansiEnabled

// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
@@ -600,13 +600,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}

-private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow
+private val nullOnOverflow = !SQLConf.get.ansiEnabled

/**
* Change the precision / scale in a given decimal to those set in `decimalType` (if any),
* modifying `value` in-place and returning it if successful. If an overflow occurs, it
* either returns null or throws an exception according to the value set for
-* `spark.sql.decimalOperations.nullOnOverflow`.
+* `spark.sql.ansi.enabled`.
*
* NOTE: this modifies `value` in-place, so don't call it on external data.
*/
@@ -625,7 +625,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String

/**
* Create new `Decimal` with precision and scale given in `decimalType` (if any).
-* If overflow occurs, if `spark.sql.decimalOperations.nullOnOverflow` is true, null is returned;
+* If overflow occurs, if `spark.sql.ansi.enabled` is false, null is returned;
* otherwise, an `ArithmeticException` is thrown.
*/
private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
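To make the null-versus-exception behaviour described in the `changePrecision`/`toPrecision` comments above concrete, here is a hedged sketch (the query and the expected outcomes are illustrative assumptions, not part of the patch; `spark` is an existing `SparkSession`):

```scala
// 12345.678 does not fit DECIMAL(5,2), so the cast overflows.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT CAST(12345.678 AS DECIMAL(5,2))").show()  // expected: NULL, the legacy behaviour

spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT CAST(12345.678 AS DECIMAL(5,2))").show()  // expected: ArithmeticException on overflow
```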
@@ -91,7 +91,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
}

override lazy val evaluateExpression: Expression = resultType match {
-case d: DecimalType => CheckOverflow(sum, d, SQLConf.get.decimalOperationsNullOnOverflow)
+case d: DecimalType => CheckOverflow(sum, d, !SQLConf.get.ansiEnabled)
case _ => sum
}

@@ -35,7 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval
""")
case class UnaryMinus(child: Expression) extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {
-private val checkOverflow = SQLConf.get.failOnIntegralTypeOverflow
+private val checkOverflow = SQLConf.get.ansiEnabled

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

@@ -136,7 +136,7 @@ case class Abs(child: Expression)

abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {

-protected val checkOverflow = SQLConf.get.failOnIntegralTypeOverflow
+protected val checkOverflow = SQLConf.get.ansiEnabled

override def dataType: DataType = left.dataType

@@ -47,7 +47,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
*/
case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {

-private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow
+private val nullOnOverflow = !SQLConf.get.ansiEnabled

override def dataType: DataType = DecimalType(precision, scale)
override def nullable: Boolean = child.nullable || nullOnOverflow
@@ -1363,7 +1363,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}

override def visitCurrentDatetime(ctx: CurrentDatetimeContext): Expression = withOrigin(ctx) {
-if (conf.ansiParserEnabled) {
+if (conf.ansiEnabled) {
ctx.name.getType match {
case SqlBaseParser.CURRENT_DATE =>
CurrentDate()
@@ -92,15 +92,15 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
lexer.removeErrorListeners()
lexer.addErrorListener(ParseErrorListener)
lexer.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced
-lexer.ansi = SQLConf.get.ansiParserEnabled
+lexer.ansi = SQLConf.get.ansiEnabled

val tokenStream = new CommonTokenStream(lexer)
val parser = new SqlBaseParser(tokenStream)
parser.addParseListener(PostProcessor)
parser.removeErrorListeners()
parser.addErrorListener(ParseErrorListener)
parser.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced
-parser.ansi = SQLConf.get.ansiParserEnabled
+parser.ansi = SQLConf.get.ansiEnabled

try {
try {
@@ -411,12 +411,6 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

-val ANSI_SQL_PARSER =
-buildConf("spark.sql.parser.ansi.enabled")
-.doc("When true, tries to conform to ANSI SQL syntax.")
-.booleanConf
-.createWithDefault(false)
-
val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
.internal()
.doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -1557,16 +1551,6 @@
.booleanConf
.createWithDefault(true)

-val DECIMAL_OPERATIONS_NULL_ON_OVERFLOW =
-buildConf("spark.sql.decimalOperations.nullOnOverflow")
-.internal()
-.doc("When true (default), if an overflow on a decimal occurs, then NULL is returned. " +
-"Spark's older versions and Hive behave in this way. If turned to false, SQL ANSI 2011 " +
-"specification will be followed instead: an arithmetic exception is thrown, as most " +
-"of the SQL databases do.")
-.booleanConf
-.createWithDefault(true)
-
val LITERAL_PICK_MINIMUM_PRECISION =
buildConf("spark.sql.legacy.literal.pickMinimumPrecision")
.internal()
@@ -1723,6 +1707,14 @@
.checkValues(StoreAssignmentPolicy.values.map(_.toString))
.createOptional

+val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled")
+.doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " +
+"throw a runtime exception if an overflow occurs in any operation on integral/decimal " +
+"field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
+"the SQL parser.")
+.booleanConf
+.createWithDefault(false)
+
val SORT_BEFORE_REPARTITION =
buildConf("spark.sql.execution.sortBeforeRepartition")
.internal()
@@ -1886,15 +1878,6 @@
.booleanConf
.createWithDefault(false)

-val FAIL_ON_INTEGRAL_TYPE_OVERFLOW =
-buildConf("spark.sql.failOnIntegralTypeOverflow")
-.doc("If it is set to true, all operations on integral fields throw an " +
-"exception if an overflow occurs. If it is false (default), in case of overflow a wrong " +
-"result is returned.")
-.internal()
-.booleanConf
-.createWithDefault(false)
-
val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE =
buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere")
.internal()
@@ -2195,8 +2178,6 @@ class SQLConf extends Serializable with Logging {

def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)

-def ansiParserEnabled: Boolean = getConf(ANSI_SQL_PARSER)
-
def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)

def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
@@ -2418,10 +2399,6 @@ class SQLConf extends Serializable with Logging {

def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)

-def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW)
-
-def failOnIntegralTypeOverflow: Boolean = getConf(FAIL_ON_INTEGRAL_TYPE_OVERFLOW)
-
def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)

def continuousStreamingEpochBacklogQueueSize: Int =
@@ -2454,6 +2431,8 @@ class SQLConf extends Serializable with Logging {
def storeAssignmentPolicy: Option[StoreAssignmentPolicy.Value] =
getConf(STORE_ASSIGNMENT_POLICY).map(StoreAssignmentPolicy.withName)

+def ansiEnabled: Boolean = getConf(ANSI_ENABLED)
+
def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED)

def serializerNestedSchemaPruningEnabled: Boolean =
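A hedged illustration of the first behaviour listed in the `ANSI_ENABLED` doc string above (runtime exception on integral overflow); the query and outcomes are assumptions for illustration, not part of this diff, and `spark` is an existing `SparkSession`:

```scala
// Int.MaxValue + 1 overflows a 32-bit integer.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT 2147483647 + 1").show()  // expected: wraps around to -2147483648 (legacy behaviour)

spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT 2147483647 + 1").show()  // expected: ArithmeticException reporting the overflow
```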
@@ -432,16 +432,16 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
}

private def testOverflowingBigNumeric[T: TypeTag](bigNumeric: T, testName: String): Unit = {
-Seq(true, false).foreach { allowNullOnOverflow =>
+Seq(true, false).foreach { ansiEnabled =>
testAndVerifyNotLeakingReflectionObjects(
s"overflowing $testName, allowNullOnOverflow=$allowNullOnOverflow") {
s"overflowing $testName, ansiEnabled=$ansiEnabled") {
withSQLConf(
-SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> allowNullOnOverflow.toString
+SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString
) {
// Need to construct Encoder here rather than implicitly resolving it
// so that SQLConf changes are respected.
val encoder = ExpressionEncoder[T]()
-if (allowNullOnOverflow) {
+if (!ansiEnabled) {
val convertedBack = encoder.resolveAndBind().fromRow(encoder.toRow(bigNumeric))
assert(convertedBack === null)
} else {
@@ -169,7 +169,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}

private def testDecimalOverflow(schema: StructType, row: Row): Unit = {
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
val encoder = RowEncoder(schema).resolveAndBind()
intercept[Exception] {
encoder.toRow(row)
@@ -182,7 +182,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}
}

withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
val encoder = RowEncoder(schema).resolveAndBind()
assert(encoder.fromRow(encoder.toRow(row)).get(0) == null)
}
@@ -61,7 +61,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L)

Seq("true", "false").foreach { checkOverflow =>
-withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> checkOverflow) {
+withSQLConf(SQLConf.ANSI_ENABLED.key -> checkOverflow) {
DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Add, tpe, tpe)
}
@@ -80,7 +80,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue)
checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue)
checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue)
withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> "true") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
checkExceptionInExpression[ArithmeticException](
UnaryMinus(Literal(Long.MinValue)), "overflow")
checkExceptionInExpression[ArithmeticException](
@@ -122,7 +122,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong)

Seq("true", "false").foreach { checkOverflow =>
-withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> checkOverflow) {
+withSQLConf(SQLConf.ANSI_ENABLED.key -> checkOverflow) {
DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Subtract, tpe, tpe)
}
@@ -144,7 +144,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong)

Seq("true", "false").foreach { checkOverflow =>
-withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> checkOverflow) {
+withSQLConf(SQLConf.ANSI_ENABLED.key -> checkOverflow) {
DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Multiply, tpe, tpe)
}
@@ -445,12 +445,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
val e4 = Add(minLongLiteral, minLongLiteral)
val e5 = Subtract(minLongLiteral, maxLongLiteral)
val e6 = Multiply(minLongLiteral, minLongLiteral)
withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> "true") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> "false") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
checkEvaluation(e1, Long.MinValue)
checkEvaluation(e2, Long.MinValue)
checkEvaluation(e3, -2L)
@@ -469,12 +469,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
val e4 = Add(minIntLiteral, minIntLiteral)
val e5 = Subtract(minIntLiteral, maxIntLiteral)
val e6 = Multiply(minIntLiteral, minIntLiteral)
withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> "true") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> "false") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
checkEvaluation(e1, Int.MinValue)
checkEvaluation(e2, Int.MinValue)
checkEvaluation(e3, -2)
@@ -493,12 +493,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
val e4 = Add(minShortLiteral, minShortLiteral)
val e5 = Subtract(minShortLiteral, maxShortLiteral)
val e6 = Multiply(minShortLiteral, minShortLiteral)
withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> "true") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> "false") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
checkEvaluation(e1, Short.MinValue)
checkEvaluation(e2, Short.MinValue)
checkEvaluation(e3, (-2).toShort)
@@ -517,12 +517,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
val e4 = Add(minByteLiteral, minByteLiteral)
val e5 = Subtract(minByteLiteral, maxByteLiteral)
val e6 = Multiply(minByteLiteral, minByteLiteral)
withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> "true") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
withSQLConf(SQLConf.FAIL_ON_INTEGRAL_TYPE_OVERFLOW.key -> "false") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
checkEvaluation(e1, Byte.MinValue)
checkEvaluation(e2, Byte.MinValue)
checkEvaluation(e3, (-2).toByte)
(The remaining changed files are not shown here.)
