[SPARK-41413][SQL] Avoid shuffle in Storage-Partitioned Join when partition keys mismatch, but join expressions are compatible

### What changes were proposed in this pull request?

This enhances Storage Partitioned Join by handling mismatched partition keys between the two sides of the join and skipping the shuffle in certain cases.

### Why are the changes needed?

Currently in Storage Partitioned Join, when the partition transform expressions match but the partition keys don't, we still fall back to a shuffle. This is unnecessary: we can compute the common set of partition keys and push it down to the scan nodes, where the missing partition keys are filled in with empty partitions.

The above scenario is quite common for `MERGE INTO` queries, since the incoming data to be merged into the base table often needs to be applied to new partitions. The current implementation causes these queries to trigger a shuffle and thus become expensive.
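
As a rough illustration of the scenario (the catalog, table, and column names below are hypothetical, and the example assumes a v2 data source that supports `MERGE INTO` and reports key-grouped partitioning on `days(ts)` for both tables):

```scala
// Both tables are partitioned by days(ts), but the incoming `changes` table
// only holds a few recent days, so its set of partition values differs from
// the base table's.
spark.sql("""
  MERGE INTO cat.db.base AS t
  USING cat.db.changes AS s
  ON t.id = s.id AND t.ts = s.ts
  WHEN MATCHED THEN UPDATE SET *
  WHEN NOT MATCHED THEN INSERT *
""")
// Previously, the mismatched partition values forced a shuffle on both sides;
// with this change Spark can merge the two sets of partition values and keep
// the storage-partitioned join shuffle-free.
```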

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a few new tests in `KeyGroupedPartitioningSuite`.
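
The new tests themselves are not part of this excerpt. As a non-authoritative sketch, the kind of assertion such a test would make (assuming a suite that mixes in `SQLTestUtils` and `AdaptiveSparkPlanHelper`, and two key-grouped v2 tables `testcat.ns.t1` and `testcat.ns.t2` partitioned on compatible transforms but with different partition values) could look like:

```scala
withSQLConf(
    SQLConf.V2_BUCKETING_ENABLED.key -> "true",
    SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") {
  val df = sql(
    "SELECT * FROM testcat.ns.t1 t1 JOIN testcat.ns.t2 t2 ON t1.a = t2.b")
  // Even though the two scans report different partition values, no shuffle
  // should be inserted on either side of the join.
  val shuffles = collect(df.queryExecution.executedPlan) {
    case s: ShuffleExchangeExec => s
  }
  assert(shuffles.isEmpty)
}
```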

Closes apache#38950 from sunchao/SPARK-41413.

Authored-by: Chao Sun <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
sunchao committed Dec 22, 2022
1 parent b3276ef commit ebbe0b0
Showing 12 changed files with 310 additions and 36 deletions.
@@ -59,7 +59,7 @@ class AvroRowReaderSuite

val df = spark.read.format("avro").load(dir.getCanonicalPath)
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
case BatchScanExec(_, f: AvroScan, _, _, _, _, _) => f
}
val filePath = fileScan.get.fileIndex.inputFiles(0)
val fileSize = new File(new URI(filePath)).length
@@ -2350,7 +2350,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper
})

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
case BatchScanExec(_, f: AvroScan, _, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@@ -2383,7 +2383,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper
assert(filterCondition.isDefined)

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
case BatchScanExec(_, f: AvroScan, _, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
@@ -2464,7 +2464,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper
.where("value = 'a'")

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
case BatchScanExec(_, f: AvroScan, _, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
if (filtersPushdown) {
@@ -677,7 +677,7 @@ case class KeyGroupedShuffleSpec(
}
}

private lazy val ordering: Ordering[InternalRow] =
lazy val ordering: Ordering[InternalRow] =
RowOrdering.createNaturalAscendingOrdering(partitioning.expressions.map(_.dataType))

override def numPartitions: Int = partitioning.numPartitions
@@ -694,28 +694,34 @@
// transform functions.
// 4. the partition values, if present on both sides, are following the same order.
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
val expressions = partitioning.expressions
val otherExpressions = otherPartitioning.expressions

distribution.clustering.length == otherDistribution.clustering.length &&
numPartitions == other.numPartitions &&
expressions.length == otherExpressions.length && {
val otherKeyPositions = otherSpec.keyPositions
keyPositions.zip(otherKeyPositions).forall { case (left, right) =>
left.intersect(right).nonEmpty
}
} && expressions.zip(otherExpressions).forall {
case (l, r) => isExpressionCompatible(l, r)
} && partitioning.partitionValuesOpt.zip(otherPartitioning.partitionValuesOpt).forall {
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
partitioning.partitionValuesOpt.zip(otherPartitioning.partitionValuesOpt).forall {
case (left, right) => left.zip(right).forall { case (l, r) =>
ordering.compare(l, r) == 0
}
}
}
case ShuffleSpecCollection(specs) =>
specs.exists(isCompatibleWith)
case _ => false
}

// Whether the partition keys (i.e., partition expressions) are compatible between this and the
// `other` spec.
def areKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = {
val expressions = partitioning.expressions
val otherExpressions = other.partitioning.expressions

expressions.length == otherExpressions.length && {
val otherKeyPositions = other.keyPositions
keyPositions.zip(otherKeyPositions).forall { case (left, right) =>
left.intersect(right).nonEmpty
}
} && expressions.zip(otherExpressions).forall {
case (l, r) => isExpressionCompatible(l, r)
}
}

private def isExpressionCompatible(left: Expression, right: Expression): Boolean =
(left, right) match {
case (_: LeafExpression, _: LeafExpression) => true
@@ -1405,6 +1405,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_PUSH_PART_VALUES_ENABLED =
buildConf("spark.sql.sources.v2.bucketing.pushPartValues.enabled")
.doc(s"Whether to pushdown common partition values when ${V2_BUCKETING_ENABLED.key} is " +
"enabled. When turned on, if both sides of a join are of KeyGroupedPartitioning and if " +
"they share compatible partition keys, even if they don't have the exact same partition " +
"values, Spark will calculate a superset of partition values and pushdown that info to " +
"scan nodes, which will use empty partitions for the missing partition values on either " +
"side. This could help to eliminate unnecessary shuffles")
.version("3.4.0")
.booleanConf
.createWithDefault(false)

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -4534,6 +4546,9 @@ class SQLConf extends Serializable with Logging {

def v2BucketingEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_ENABLED)

def v2BucketingPushPartValuesEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
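
For reference (not part of the diff), enabling the new behavior at runtime would look roughly like this; the first key is the existing v2 bucketing flag referenced by the new config's doc, and both default to false:

```scala
// The new flag only has an effect when the existing v2 bucketing flag is on.
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
```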

@@ -24,7 +24,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.InternalRowSet
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.Table
@@ -39,14 +39,16 @@ case class BatchScanExec(
runtimeFilters: Seq[Expression],
keyGroupedPartitioning: Option[Seq[Expression]] = None,
ordering: Option[Seq[SortOrder]] = None,
@transient table: Table) extends DataSourceV2ScanExecBase {
@transient table: Table,
commonPartitionValues: Option[Seq[InternalRow]] = None) extends DataSourceV2ScanExecBase {

@transient lazy val batch = scan.toBatch

// TODO: unify the equal/hashCode implementation for all data source v2 query plans.
override def equals(other: Any): Boolean = other match {
case other: BatchScanExec =>
this.batch == other.batch && this.runtimeFilters == other.runtimeFilters
this.batch == other.batch && this.runtimeFilters == other.runtimeFilters &&
this.commonPartitionValues == other.commonPartitionValues
case _ =>
false
}
@@ -110,6 +112,15 @@ case class BatchScanExec(
}
}

override def outputPartitioning: Partitioning = {
super.outputPartitioning match {
case k: KeyGroupedPartitioning if commonPartitionValues.isDefined =>
val values = commonPartitionValues.get
k.copy(numPartitions = values.length, partitionValuesOpt = Some(values))
case p => p
}
}

override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()

override lazy val inputRDD: RDD[InternalRow] = {
@@ -123,9 +134,9 @@
case p: KeyGroupedPartitioning =>
val partitionMapping = finalPartitions.map(s =>
s.head.asInstanceOf[HasPartitionKey].partitionKey() -> s).toMap
finalPartitions = p.partitionValuesOpt.get.map { partKey =>
// Use empty partition for those partition keys that are not present
partitionMapping.getOrElse(partKey, Seq.empty)
finalPartitions = p.partitionValuesOpt.get.map { partValue =>
// Use empty partition for those partition values that are not present
partitionMapping.getOrElse(partValue, Seq.empty)
}
case _ =>
}
@@ -20,12 +20,15 @@ package org.apache.spark.sql.execution.exchange
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.collection.Utils

/**
* Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
@@ -142,12 +145,81 @@ case class EnsureRequirements(
Some(finalCandidateSpecs.values.maxBy(_.numPartitions))
}

// Check if 1) all children are of `KeyGroupedPartitioning` and 2) they are all compatible
// with each other. If both are true, skip shuffle.
val allCompatible = childrenIndexes.sliding(2).forall {
case Seq(a, b) =>
checkKeyGroupedSpec(specs(a)) && checkKeyGroupedSpec(specs(b)) &&
specs(a).isCompatibleWith(specs(b))
// Retrieve the non-collection spec from the input
def getRootSpec(spec: ShuffleSpec): ShuffleSpec = spec match {
case ShuffleSpecCollection(specs) => getRootSpec(specs.head)
case spec => spec
}

// Populate the common partition values down to the scan nodes
def populatePartitionValues(plan: SparkPlan, values: Seq[InternalRow]): SparkPlan =
plan match {
case scan: BatchScanExec =>
scan.copy(commonPartitionValues = Some(values))
case node =>
node.mapChildren(child => populatePartitionValues(child, values))
}

// Check if the following conditions are satisfied:
// 1. There are exactly two children (e.g., join). Note that Spark doesn't support
// multi-way join at the moment, so this check should be sufficient.
// 2. All children are of `KeyGroupedPartitioning`, and they are compatible with each other
// If both are true, skip shuffle.
val allCompatible = childrenIndexes.length == 2 && {
val left = childrenIndexes.head
val right = childrenIndexes(1)
var isCompatible: Boolean = false

if (checkKeyGroupedSpec(specs(left)) && checkKeyGroupedSpec(specs(right))) {
isCompatible = specs(left).isCompatibleWith(specs(right))

// If `isCompatible` is false, it could mean:
// 1. Partition keys (expressions) are not compatible: we have to shuffle in this case.
// 2. Partition keys (expressions) are compatible, but partition values are not: in this
// case we can compute a superset of partition values and push-down to respective
// data sources, which can then adjust their respective output partitioning by
// filling missing partition values with empty partitions. As result, Spark can still
// avoid shuffle.
//
// For instance, if two sides of a join have partition expressions `day(a)` and `day(b)`
// respectively (the join query could be `SELECT ... FROM t1 JOIN t2 on t1.a = t2.b`),
// but with different partition values:
// `day(a)`: [0, 1]
// `day(b)`: [1, 2, 3]
// Following the case 2 above, we don't have to shuffle both sides, but instead can just
// push the common set of partition values: `[0, 1, 2, 3]` down to the two data sources.
if (!isCompatible && conf.v2BucketingPushPartValuesEnabled) {
(getRootSpec(specs(left)), getRootSpec(specs(right))) match {
case (leftSpec: KeyGroupedShuffleSpec, rightSpec: KeyGroupedShuffleSpec) =>
// Check if the two children are partition keys compatible. If so, find the
// common set of partition values, and adjust the plan accordingly.
if (leftSpec.areKeysCompatible(rightSpec)) {
assert(leftSpec.partitioning.partitionValuesOpt.isDefined)
assert(rightSpec.partitioning.partitionValuesOpt.isDefined)

val leftPartValues = leftSpec.partitioning.partitionValuesOpt.get
val rightPartValues = rightSpec.partitioning.partitionValuesOpt.get

val mergedPartValues = Utils.mergeOrdered(
Seq(leftPartValues, rightPartValues))(leftSpec.ordering).toSeq.distinct

// Now we need to push-down the common partition key to the scan in each child
children = children.zipWithIndex.map {
case (child, idx) if childrenIndexes.contains(idx) =>
populatePartitionValues(child, mergedPartValues)
case (child, _) => child
}

isCompatible = true
}
case _ =>
// This case shouldn't happen since `checkKeyGroupedSpec` should've made
// sure that we only have `KeyGroupedShuffleSpec`
}
}
}

isCompatible
}

children = children.zip(requiredChildDistributions).zipWithIndex.map {
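
To make the merging step above concrete, here is a small sketch (not from the patch) of what `Utils.mergeOrdered(...).toSeq.distinct` computes, using plain integers in place of `InternalRow` partition values:

```scala
// Reusing the day(a)/day(b) example from the comment above:
val leftPartValues  = Seq(0, 1)      // partition values reported by the left scan
val rightPartValues = Seq(1, 2, 3)   // partition values reported by the right scan

// The patch merges two already-sorted sequences under the partitioning's row
// ordering; for plain Ints that is equivalent to:
val mergedPartValues = (leftPartValues ++ rightPartValues).sorted.distinct
// => Seq(0, 1, 2, 3); this superset is pushed to both scans, which substitute
// empty partitions for any value they do not have.
```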
@@ -861,7 +861,7 @@ class FileBasedDataSourceSuite extends QueryTest
})

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _) => f
case BatchScanExec(_, f: FileScan, _, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@@ -901,7 +901,7 @@ class FileBasedDataSourceSuite extends QueryTest
assert(filterCondition.isDefined)

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _) => f
case BatchScanExec(_, f: FileScan, _, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)