Commit 8e835e1

Fix ArrayIndexOutOfBoundsException on join counts with constant join keys (#11244)

* Fix ArrayIndexOutOfBoundsException on join counts with constant join keys

Signed-off-by: Jason Lowe <[email protected]>

* Handle GpuAlias

---------

Signed-off-by: Jason Lowe <[email protected]>
jlowe authored Jul 26, 2024
1 parent 4bdbd19 commit 8e835e1
Showing 2 changed files with 24 additions and 4 deletions.
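
For context: when both join keys are constants, Spark detects this and replaces the equi-join condition with a literal true, producing a degenerate broadcast nested loop join. If the query only needs a row count, the join has no output columns, and that combination previously raised the ArrayIndexOutOfBoundsException on the GPU. A minimal standalone repro sketch, modeled on the new test below (the session setup is an assumption; any SparkSession with the RAPIDS Accelerator enabled, e.g. spark.plugins=com.nvidia.spark.SQLPlugin, should exercise the same path):

# Repro sketch; assumes a SparkSession with the RAPIDS Accelerator enabled.
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit

spark = SparkSession.builder.appName("constant-key-join-repro").getOrCreate()

# Both join keys are constants, so Spark folds the join condition to a
# literal true and plans a degenerate broadcast nested loop join.
left = spark.range(10).withColumn("s", lit(1))
right = spark.range(10000).withColumn("r_s", lit(1))

# Row-count-only query: no join output columns are needed, the case that
# previously triggered the exception.
print(left.join(right.hint("broadcast"), left.s == right.r_s, "inner").count())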
14 changes: 11 additions & 3 deletions integration_tests/src/main/python/join_test.py
@@ -14,10 +14,10 @@
 
 import pytest
 from _pytest.mark.structures import ParameterSet
-from pyspark.sql.functions import array_contains, broadcast, col
+from pyspark.sql.functions import array_contains, broadcast, col, lit
 from pyspark.sql.types import *
-from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
-from conftest import is_databricks_runtime, is_emr_runtime, is_not_utc
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_row_counts_equal, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
+from conftest import is_emr_runtime
 from data_gen import *
 from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan
 from spark_session import with_cpu_session, is_before_spark_330, is_databricks_runtime
@@ -164,6 +164,14 @@ def do_join(spark):
         return left.join(right.hint("broadcast"), left.a == right.r_a, join_type)
     assert_gpu_and_cpu_are_equal_collect(do_join, conf={'spark.sql.adaptive.enabled': 'true'})
 
+@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
+def test_broadcast_hash_join_constant_keys(join_type):
+    def do_join(spark):
+        left = spark.range(10).withColumn("s", lit(1))
+        right = spark.range(10000).withColumn("r_s", lit(1))
+        return left.join(right.hint("broadcast"), left.s == right.r_s, join_type)
+    assert_gpu_and_cpu_row_counts_equal(do_join, conf={'spark.sql.adaptive.enabled': 'true'})
+
 
 # local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
 # After 3.1.0 is the min spark version we can drop this
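
Note that the new test asserts matching row counts rather than collected results: a count-only plan leaves the join output empty, which is exactly the degenerate shape the Scala change below checks for. Continuing from the repro sketch above, one way to see the plan shape (a sketch, not part of this commit):

# The constant keys fold the condition to literal true, so the physical
# plan contains a nested loop join rather than a hash join.
df = left.join(right.hint("broadcast"), left.s == right.r_s, "inner")
df.explain()       # expect a (Gpu)BroadcastNestedLoopJoin node
print(df.count())  # the row-count-only execution hits the fixed code path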
14 changes: 13 additions & 1 deletion
@@ -558,6 +558,16 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
     broadcastExchange.executeColumnarBroadcast[Any]()
   }
 
+  private def isUnconditionalJoin(condition: Option[GpuExpression]): Boolean = {
+    condition.forall {
+      case GpuLiteral(true, BooleanType) =>
+        // Spark can generate a degenerate conditional join when the join keys are constants
+        output.isEmpty
+      case GpuAlias(e: GpuExpression, _) => isUnconditionalJoin(Some(e))
+      case _ => false
+    }
+  }
+
   override def internalDoExecuteColumnar(): RDD[ColumnarBatch] = {
     // Determine which table will be first in the join and bind the references accordingly
     // so the AST column references match the appropriate table.
@@ -583,7 +593,9 @@
       if (useTrueCondition) Some(GpuLiteral(true)) else None
     }
 
-    if (joinCondition.isEmpty) {
+    // Sometimes Spark specifies a true condition for a row-count-only join.
+    // This can happen when the join keys are detected to be constant.
+    if (isUnconditionalJoin(joinCondition)) {
       doUnconditionalJoin(broadcastRelation)
     } else {
       doConditionalJoin(broadcastRelation, joinCondition, numFirstTableColumns)
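
Design note: a literal-true condition on a join with no output columns is now treated the same as no condition at all, so the row-count-only case takes doUnconditionalJoin instead of the conditional-join path that raised the ArrayIndexOutOfBoundsException. The GpuAlias case (the "Handle GpuAlias" follow-up in this squashed commit) recursively unwraps an aliased condition so the same detection applies.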
