diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index 3c7b9ee317222..3a7fcf1fa9d89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -262,7 +262,9 @@ class SortBasedAggregator(
         // Firstly, update the aggregation buffer with input rows.
         while (hasNextInput &&
           groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
-          processRow(result.aggregationBuffer, inputIterator.getValue)
+          // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
+          // overwritten when `inputIterator` steps forward, we need to do a deep copy here.
+          processRow(result.aggregationBuffer, inputIterator.getValue.copy())
           hasNextInput = inputIterator.next()
         }
 
@@ -271,7 +273,12 @@ class SortBasedAggregator(
         // be called after calling processRow.
         while (hasNextAggBuffer &&
           groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
-          mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
+          mergeAggregationBuffers(
+            result.aggregationBuffer,
+            // Since `initialAggBufferIterator.getValue` is an `UnsafeRow` whose underlying
+            // buffer will be overwritten when the iterator steps forward, deep copy it here.
+            initialAggBufferIterator.getValue.copy()
+          )
           hasNextAggBuffer = initialAggBufferIterator.next()
         }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index b7f91d8c3a797..9a8d4498bba2f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -205,23 +205,19 @@ class ObjectHashAggregateSuite
     // A TypedImperativeAggregate function
     val typed = percentile_approx($"c0", 0.5)
 
-    // A Hive UDAF without partial aggregation support
-    val withoutPartial = function("hive_max", $"c1")
-
     // A Spark SQL native aggregate function with partial aggregation support that can be executed
     // by the Tungsten `HashAggregateExec`
-    val withPartialUnsafe = max($"c2")
+    val withPartialUnsafe = max($"c1")
 
     // A Spark SQL native aggregate function with partial aggregation support that can only be
-    // executed by the Tungsten `HashAggregateExec`
-    val withPartialSafe = max($"c3")
+    // executed by the `SortAggregateExec`
+    val withPartialSafe = max($"c2")
 
     // A Spark SQL native distinct aggregate function
-    val withDistinct = countDistinct($"c4")
+    val withDistinct = countDistinct($"c3")
 
     val allAggs = Seq(
       "typed" -> typed,
-      "without partial" -> withoutPartial,
       "with partial + unsafe" -> withPartialUnsafe,
       "with partial + safe" -> withPartialSafe,
       "with distinct" -> withDistinct
@@ -276,10 +272,9 @@ class ObjectHashAggregateSuite
     // Generates a random schema for the randomized data generator
     val schema = new StructType()
       .add("c0", numericTypes(random.nextInt(numericTypes.length)), nullable = true)
-      .add("c1", orderedTypes(random.nextInt(orderedTypes.length)), nullable = true)
-      .add("c2", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
-      .add("c3", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
-      .add("c4", allTypes(random.nextInt(allTypes.length)), nullable = true)
+      .add("c1", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
+      .add("c2", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
+      .add("c3", allTypes(random.nextInt(allTypes.length)), nullable = true)
 
     logInfo(
       s"""Using the following random schema to generate all the randomized aggregation tests:
@@ -325,70 +320,67 @@ class ObjectHashAggregateSuite
 
       // Currently Spark SQL doesn't support evaluating distinct aggregate function together
       // with aggregate functions without partial aggregation support.
-      if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) {
-        // TODO Re-enables them after fixing SPARK-18403
-        ignore(
-          s"randomized aggregation test - " +
-            s"${names.mkString("[", ", ", "]")} - " +
-            s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
-            s"with ${if (emptyInput) "empty" else "non-empty"} input"
-        ) {
-          var expected: Seq[Row] = null
-          var actual1: Seq[Row] = null
-          var actual2: Seq[Row] = null
-
-          // Disables `ObjectHashAggregateExec` to obtain a standard answer
-          withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
-            val aggDf = doAggregation(df)
-
-            if (aggs.intersect(Seq(withoutPartial, withPartialSafe, typed)).nonEmpty) {
-              assert(containsSortAggregateExec(aggDf))
-              assert(!containsObjectHashAggregateExec(aggDf))
-              assert(!containsHashAggregateExec(aggDf))
-            } else {
-              assert(!containsSortAggregateExec(aggDf))
-              assert(!containsObjectHashAggregateExec(aggDf))
-              assert(containsHashAggregateExec(aggDf))
-            }
-
-            expected = aggDf.collect().toSeq
+      test(
+        s"randomized aggregation test - " +
+          s"${names.mkString("[", ", ", "]")} - " +
+          s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
+          s"with ${if (emptyInput) "empty" else "non-empty"} input"
+      ) {
+        var expected: Seq[Row] = null
+        var actual1: Seq[Row] = null
+        var actual2: Seq[Row] = null
+
+        // Disables `ObjectHashAggregateExec` to obtain a standard answer
+        withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
+          val aggDf = doAggregation(df)
+
+          if (aggs.intersect(Seq(withPartialSafe, typed)).nonEmpty) {
+            assert(containsSortAggregateExec(aggDf))
+            assert(!containsObjectHashAggregateExec(aggDf))
+            assert(!containsHashAggregateExec(aggDf))
+          } else {
+            assert(!containsSortAggregateExec(aggDf))
+            assert(!containsObjectHashAggregateExec(aggDf))
+            assert(containsHashAggregateExec(aggDf))
           }
 
-          // Enables `ObjectHashAggregateExec`
-          withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
-            val aggDf = doAggregation(df)
-
-            if (aggs.contains(typed) && !aggs.contains(withoutPartial)) {
-              assert(!containsSortAggregateExec(aggDf))
-              assert(containsObjectHashAggregateExec(aggDf))
-              assert(!containsHashAggregateExec(aggDf))
-            } else if (aggs.intersect(Seq(withoutPartial, withPartialSafe)).nonEmpty) {
-              assert(containsSortAggregateExec(aggDf))
-              assert(!containsObjectHashAggregateExec(aggDf))
-              assert(!containsHashAggregateExec(aggDf))
-            } else {
-              assert(!containsSortAggregateExec(aggDf))
-              assert(!containsObjectHashAggregateExec(aggDf))
-              assert(containsHashAggregateExec(aggDf))
-            }
-
-            // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
-            // big enough) to obtain a result to be checked.
-            withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
-              actual1 = aggDf.collect().toSeq
-            }
-
-            // Enables sort-based aggregation fallback to obtain another result to be checked.
-            withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
-              // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
-              // cached and won't be re-planned using the new fallback threshold.
-              actual2 = doAggregation(df).collect().toSeq
-            }
+          expected = aggDf.collect().toSeq
+        }
+
+        // Enables `ObjectHashAggregateExec`
+        withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+          val aggDf = doAggregation(df)
+
+          if (aggs.contains(typed)) {
+            assert(!containsSortAggregateExec(aggDf))
+            assert(containsObjectHashAggregateExec(aggDf))
+            assert(!containsHashAggregateExec(aggDf))
+          } else if (aggs.contains(withPartialSafe)) {
+            assert(containsSortAggregateExec(aggDf))
+            assert(!containsObjectHashAggregateExec(aggDf))
+            assert(!containsHashAggregateExec(aggDf))
+          } else {
+            assert(!containsSortAggregateExec(aggDf))
+            assert(!containsObjectHashAggregateExec(aggDf))
+            assert(containsHashAggregateExec(aggDf))
          }
 
-          doubleSafeCheckRows(actual1, expected, 1e-4)
-          doubleSafeCheckRows(actual2, expected, 1e-4)
+          // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
+          // big enough) to obtain a result to be checked.
+          withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+            actual1 = aggDf.collect().toSeq
+          }
+
+          // Enables sort-based aggregation fallback to obtain another result to be checked.
+          withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
+            // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
+            // cached and won't be re-planned using the new fallback threshold.
+            actual2 = doAggregation(df).collect().toSeq
+          }
         }
+
+        doubleSafeCheckRows(actual1, expected, 1e-4)
+        doubleSafeCheckRows(actual2, expected, 1e-4)
       }
     }
   }
@@ -425,7 +417,35 @@ class ObjectHashAggregateSuite
     }
   }
 
-  private def function(name: String, args: Column*): Column = {
-    Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false))
+  test("SPARK-18403 Fix unsafe data false sharing issue in ObjectHashAggregateExec") {
+    // SPARK-18403: An unsafe data false sharing issue may trigger OOM / SIGSEGV when evaluating
+    // certain aggregate functions. To reproduce this issue, the following conditions must be
+    // met:
+    //
+    // 1. The aggregation must be evaluated using `ObjectHashAggregateExec`;
+    // 2. There must be an input column whose data type involves `ArrayType` or `MapType`;
+    // 3. Sort-based aggregation fallback must be triggered during evaluation.
+    withSQLConf(
+      SQLConf.USE_OBJECT_HASH_AGG.key -> "true",
+      SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1"
+    ) {
+      checkAnswer(
+        Seq
+          .fill(2)(Tuple1(Array.empty[Int]))
+          .toDF("c0")
+          .groupBy(lit(1))
+          .agg(typed_count($"c0"), max($"c0")),
+        Row(1, 2, Array.empty[Int])
+      )
+
+      checkAnswer(
+        Seq
+          .fill(2)(Tuple1(Map.empty[Int, Int]))
+          .toDF("c0")
+          .groupBy(lit(1))
+          .agg(typed_count($"c0"), first($"c0")),
+        Row(1, 2, Map.empty[Int, Int])
+      )
+    }
   }
 }
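--
Reviewer notes (not part of the patch):

1. The two `.copy()` calls above guard against `UnsafeRow` buffer reuse: the sorter
iterators in `SortBasedAggregator` hand out rows that point into a backing buffer which
is overwritten on every `next()`. The sketch below is a minimal illustration of that
aliasing behavior, not code from this patch; it assumes only the internal catalyst
classes `UnsafeProjection`, `InternalRow`, and `UnsafeRow`, whose generated projections
reuse their result buffer the same way.

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
  import org.apache.spark.sql.types.{DataType, IntegerType}

  object UnsafeRowAliasingSketch extends App {
    // A generated UnsafeProjection writes every input into one reused result row,
    // just like the iterators whose values the patch now copies.
    val project = UnsafeProjection.create(Array[DataType](IntegerType))

    val alias: UnsafeRow = project(InternalRow(1)) // points into the shared buffer
    val snapshot: UnsafeRow = alias.copy()         // deep copy into a fresh buffer

    project(InternalRow(2))                        // overwrites the shared buffer

    assert(alias.getInt(0) == 2)    // the retained reference silently changed
    assert(snapshot.getInt(0) == 1) // the copy still holds the original value
  }

Primitive values are copied out of the row by value during aggregation, while array- and
map-typed values are retained as references into it, which is why the new test needs an
`ArrayType` or `MapType` column (condition 2) to surface the bug.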
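2. A hypothetical standalone reproduction of the scenario the new regression test
encodes, runnable in spark-shell with only public APIs. The two configuration key
strings are quoted directly because `SQLConf` is internal; both keys, and the choice of
`percentile_approx` as the `TypedImperativeAggregate` that routes the plan through
`ObjectHashAggregateExec`, are assumptions rather than something this patch specifies.

  import org.apache.spark.sql.functions._
  import spark.implicits._

  // Condition 1: evaluate aggregates with ObjectHashAggregateExec (assumed key string).
  spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "true")
  // Condition 3: force the sort-based fallback almost immediately (assumed key string).
  spark.conf.set("spark.sql.objectHashAggregate.sortBased.fallbackThreshold", "1")

  // Condition 2: an input column whose data type involves ArrayType.
  val df = Seq.fill(2)((1, Array.empty[Int])).toDF("c0", "c1")

  // percentile_approx is a TypedImperativeAggregate, so ObjectHashAggregateExec is
  // chosen; max($"c1") holds on to the array value across input rows. Before this
  // patch, the un-copied row behind that value could be overwritten during the
  // sort-based fallback, corrupting the result.
  df.groupBy(lit(1))
    .agg(expr("percentile_approx(c0, 0.5)"), max($"c1"))
    .show()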