Skip to content

Commit

Permalink
[SPARK-12656] [SQL] Implement Intersect with Left-semi Join
Browse files Browse the repository at this point in the history
Our current Intersect physical operator simply delegates to RDD.intersect. We should remove the Intersect physical operator and simply transform a logical intersect into a semi-join with distinct. This way, we can take advantage of all the benefits of join implementations (e.g. managed memory, code generation, broadcast joins).

After a search, I found that one of the mainstream RDBMSs does the same: in its query EXPLAIN output, Intersect is replaced by a left-semi join. A left-semi join can also help outer-join elimination in the Optimizer, as shown in the PR: apache#10566

Author: gatorsmile <[email protected]>
Author: xiaoli <[email protected]>
Author: Xiao Li <[email protected]>

Closes apache#10630 from gatorsmile/IntersectBySemiJoin.
  • Loading branch information
gatorsmile authored and rxin committed Jan 29, 2016
1 parent c5f745e commit 5f686cc
Show file tree
Hide file tree
Showing 11 changed files with 211 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,63 @@ class Analyzer(
}
}

/**
* Generate a new logical plan for the right child with different expression IDs
* for all conflicting attributes.
*
* An attribute is "conflicting" when it appears in the output sets of both `left` and `right`.
* Each call rewrites at most one offending node of `right` (the first match returned by
* `collect`); when no supported node produces a conflicting attribute, `right` is returned
* unchanged.
*
* @param left the left child; it is only read to compute the conflicting attributes
* @param right the right child, in which conflicting attributes are replaced
* @return `right`, or a copy of it in which one conflicting node has been replaced and
* references to that node's old output attributes have been rewritten to the new ones
*/
private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
// Attributes exposed by both children.
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
s"between $left and $right")

// Collect (node to replace, replacement) pairs for every supported node of `right`
// that is responsible for at least one conflicting attribute.
right.collect {
// Handle base relations that might appear more than once.
case oldVersion: MultiInstanceRelation
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.newInstance()
(oldVersion, newVersion)

// Handle projects that create conflicting aliases.
case oldVersion @ Project(projectList, _)
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(projectList = newAliases(projectList)))

// Handle aggregates that create conflicting aliases.
case oldVersion @ Aggregate(_, aggregateExpressions, _)
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))

// Handle generators whose generated attributes conflict: every generator output
// attribute is replaced by its `newInstance()` copy.
case oldVersion: Generate
if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
(oldVersion, oldVersion.copy(generatorOutput = newOutput))

// Handle window operators whose window expressions produce conflicting attributes.
case oldVersion @ Window(_, windowExpressions, _, _, child)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
(oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
}
// Only handle first case, others will be fixed on the next pass.
.headOption match {
case None =>
/*
* No result implies that there is a logical plan node that produces new references
* that this rule cannot handle. When that is the case, there must be another rule
* that resolves these conflicts. Otherwise, the analysis will fail.
*/
right
case Some((oldRelation, newRelation)) =>
// Pair the old node's output attributes with the replacement's, by position.
val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
// First pass swaps the node itself; second pass rewrites every attribute reference
// in the resulting plan that still points at one of the old output attributes.
val newRight = right transformUp {
case r if r == oldRelation => newRelation
} transformUp {
case other => other transformExpressions {
case a: Attribute => attributeRewrites.get(a).getOrElse(a)
}
}
newRight
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p

Expand Down Expand Up @@ -388,57 +445,11 @@ class Analyzer(
.map(_.asInstanceOf[NamedExpression])
a.copy(aggregateExpressions = expanded)

// Special handling for cases when self-join introduce duplicate expression ids.
case j @ Join(left, right, _, _) if !j.selfJoinResolved =>
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j")

right.collect {
// Handle base relations that might appear more than once.
case oldVersion: MultiInstanceRelation
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.newInstance()
(oldVersion, newVersion)

// Handle projects that create conflicting aliases.
case oldVersion @ Project(projectList, _)
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(projectList = newAliases(projectList)))

case oldVersion @ Aggregate(_, aggregateExpressions, _)
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))

case oldVersion: Generate
if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
(oldVersion, oldVersion.copy(generatorOutput = newOutput))

case oldVersion @ Window(_, windowExpressions, _, _, child)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
(oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
}
// Only handle first case, others will be fixed on the next pass.
.headOption match {
case None =>
/*
* No result implies that there is a logical plan node that produces new references
* that this rule cannot handle. When that is the case, there must be another rule
* that resolves these conflicts. Otherwise, the analysis will fail.
*/
j
case Some((oldRelation, newRelation)) =>
val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
val newRight = right transformUp {
case r if r == oldRelation => newRelation
} transformUp {
case other => other transformExpressions {
case a: Attribute => attributeRewrites.get(a).getOrElse(a)
}
}
j.copy(right = newRight)
}
// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
case i @ Intersect(left, right) if !i.duplicateResolved =>
i.copy(right = dedupRight(left, right))

// When resolving `SortOrder`s in Sort based on its child, don't report errors, as
// we still have a chance to resolve them based on the grandchild
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,16 +214,24 @@ trait CheckAnalysis {
s"""Only a single table generating function is allowed in a SELECT clause, found:
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)

// Special handling for cases when self-join introduce duplicate expression ids.
case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
case j: Join if !j.duplicateResolved =>
val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
failAnalysis(
s"""
|Failure when resolving conflicting references in Join:
|$plan
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
|""".stripMargin)

case i: Intersect if !i.duplicateResolved =>
val conflictingAttributes = i.left.outputSet.intersect(i.right.outputSet)
failAnalysis(
s"""
|Failure when resolving conflicting references in Intersect:
|$plan
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
|""".stripMargin)

case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
// since the other rules might make two separate Unions operators adjacent.
Batch("Union", Once,
CombineUnions) ::
Batch("Replace Operators", FixedPoint(100),
ReplaceIntersectWithSemiJoin,
ReplaceDistinctWithAggregate) ::
Batch("Aggregate", FixedPoint(100),
ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) ::
Batch("Operator Optimizations", FixedPoint(100),
// Operator push down
Expand Down Expand Up @@ -124,18 +126,13 @@ object EliminateSerialization extends Rule[LogicalPlan] {
}

/**
* Pushes certain operations to both sides of a Union, Intersect or Except operator.
* Pushes certain operations to both sides of a Union or Except operator.
* Operations that are safe to pushdown are listed as follows.
* Union:
* Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is
* safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT,
* we will not be able to pushdown Projections.
*
* Intersect:
* It is not safe to pushdown Projections through it because we need to get the
* intersect of rows by comparing the entire rows. It is fine to pushdown Filters
* with deterministic condition.
*
* Except:
* It is not safe to pushdown Projections through it because we need to get the
* difference of rows by comparing the entire rows. It is fine to pushdown Filters
Expand All @@ -153,7 +150,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {

/**
* Rewrites an expression so that it can be pushed to the right side of a
* Union, Intersect or Except operator. This method relies on the fact that the output attributes
* Union or Except operator. This method relies on the fact that the output attributes
* of a union/intersect/except are always equal to the left child's output.
*/
private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
Expand Down Expand Up @@ -210,17 +207,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
Filter(nondeterministic, Union(newFirstChild +: newOtherChildren))

// Push down filter through INTERSECT
case Filter(condition, Intersect(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
val rewrites = buildRewrites(left, right)
Filter(nondeterministic,
Intersect(
Filter(deterministic, left),
Filter(pushToRight(deterministic, rewrites), right)
)
)

// Push down filter through EXCEPT
case Filter(condition, Except(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
Expand Down Expand Up @@ -1054,6 +1040,27 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
}
}

/**
 * Rewrites a logical [[Intersect]] as a [[Distinct]] on top of a left-semi [[Join]] whose
 * condition compares every pair of corresponding output columns with null-safe equality.
 * {{{
 *   SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2
 *   ==>  SELECT DISTINCT a1, a2 FROM Tab1 LEFT SEMI JOIN Tab2 ON a1<=>b1 AND a2<=>b2
 * }}}
 *
 * Note:
 * 1. Only valid for INTERSECT DISTINCT; it must not be applied to INTERSECT ALL.
 * 2. Must run after the attributes of both children have been de-duplicated; otherwise the
 *    generated join conditions will be incorrect.
 */
object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case Intersect(left, right) =>
      assert(left.output.size == right.output.size)
      // One `<=>` comparison per column pair, paired by position.
      val columnEqualities = left.output.zip(right.output).map { pair =>
        EqualNullSafe(pair._1, pair._2)
      }
      // `reduceLeftOption` yields None when there are no columns, i.e. no join condition.
      val semiJoin = Join(left, right, LeftSemi, columnEqualities.reduceLeftOption(And))
      Distinct(semiJoin)
  }
}

/**
* Removes literals from group expressions in [[Aggregate]], as they have no effect on the result
* and only make the grouping key bigger.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
Expand Down Expand Up @@ -90,28 +91,38 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}

abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
final override lazy val resolved: Boolean =
childrenResolved &&
left.output.length == right.output.length &&
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
}
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode

/** Extractor that exposes the two children of any [[SetOperation]] as a (left, right) pair. */
private[sql] object SetOperation {
  def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = {
    val children = (p.left, p.right)
    Some(children)
  }
}

case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {

  /** True when the two children expose no attribute in common. */
  def duplicateResolved: Boolean = {
    val sharedAttributes = left.outputSet.intersect(right.outputSet)
    sharedAttributes.isEmpty
  }

  /**
   * The left child's attributes, paired by position with the right child's; a column is
   * nullable only when it is nullable on both sides.
   */
  override def output: Seq[Attribute] = {
    val pairedColumns = left.output.zip(right.output)
    pairedColumns.map { case (l, r) => l.withNullability(l.nullable && r.nullable) }
  }

  // An Intersect is only resolved if it does not introduce ambiguous expression ids,
  // since the Optimizer will convert Intersect to Join.
  override lazy val resolved: Boolean = {
    // Local defs keep the original short-circuit order: nothing below is evaluated
    // unless every preceding check has passed.
    def sameArity = left.output.length == right.output.length
    def sameColumnTypes =
      left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }

    childrenResolved && sameArity && sameColumnTypes && duplicateResolved
  }
}

case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
  /** Only the left child's attributes are exposed: rows of the right child get excluded. */
  override def output: Seq[Attribute] = left.output

  // Resolved once both children are resolved and their outputs agree in length and,
  // column by column, in data type.
  override lazy val resolved: Boolean = {
    // Local defs keep the original short-circuit order of the checks.
    def sameArity = left.output.length == right.output.length
    def sameColumnTypes =
      left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }

    childrenResolved && sameArity && sameColumnTypes
  }
}

/** Factory for constructing new `Union` nodes. */
Expand Down Expand Up @@ -169,13 +180,13 @@ case class Join(
}
}

def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty

// Joins are only resolved if they don't introduce ambiguous expression ids.
override lazy val resolved: Boolean = {
childrenResolved &&
expressions.forall(_.resolved) &&
selfJoinResolved &&
duplicateResolved &&
condition.forall(_.dataType == BooleanType)
}
}
Expand Down Expand Up @@ -249,7 +260,7 @@ case class Range(
end: Long,
step: Long,
numSlices: Int,
output: Seq[Attribute]) extends LeafNode {
output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation {
require(step != 0, "step cannot be 0")
val numElements: BigInt = {
val safeStart = BigInt(start)
Expand All @@ -262,6 +273,9 @@ case class Range(
}
}

// Builds a Range with the same bounds, step and slice count, whose output attributes are
// replaced by their `newInstance()` copies.
override def newInstance(): Range = {
  val freshOutput = output.map(_.newInstance())
  Range(start, end, step, numSlices, freshOutput)
}

override def statistics: Statistics = {
val sizeInBytes = LongType.defaultSize * numElements
Statistics( sizeInBytes = sizeInBytes )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ class AnalysisSuite extends AnalysisTest {
checkAnalysis(plan, expected)
}

test("self intersect should resolve duplicate expression IDs") {
  // Both children are the same relation, so they share every output attribute.
  assertAnalysisSuccess(testRelation.intersect(testRelation))
}

test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
LocalRelation()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,9 @@ class AggregateOptimizeSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Aggregate", FixedPoint(100),
ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) :: Nil
}

test("replace distinct with aggregate") {
val input = LocalRelation('a.int, 'b.int)

val query = Distinct(input)
val optimized = Optimize.execute(query.analyze)

val correctAnswer = Aggregate(input.output, input.output, input)

comparePlans(optimized, correctAnswer)
}

test("remove literals in grouping expression") {
val input = LocalRelation('a.int, 'b.int)

Expand Down
Loading

0 comments on commit 5f686cc

Please sign in to comment.