From 1707238601690fd0e8e173e2c47f1b4286644a29 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 17 Jul 2015 16:45:46 -0700 Subject: [PATCH] [SPARK-7026] [SQL] fix left semi join with equi key and non-equi condition When the `condition` extracted by `ExtractEquiJoinKeys` contain join Predicate for left semi join, we can not plan it as semiJoin. Such as SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b AND x.a >= y.a + 2 Condition `x.a >= y.a + 2` can not evaluate on table `x`, so it throw errors Author: Daoyuan Wang Closes #5643 from adrian-wang/spark7026 and squashes the following commits: cc09809 [Daoyuan Wang] refactor semijoin and add plan test 575a7c8 [Daoyuan Wang] fix notserializable 27841de [Daoyuan Wang] fix rebase 10bf124 [Daoyuan Wang] fix style 72baa02 [Daoyuan Wang] fix style 8e0afca [Daoyuan Wang] merge commits for rebase --- .../spark/sql/execution/SparkStrategies.scala | 10 +- .../joins/BroadcastLeftSemiJoinHash.scala | 42 ++++----- .../sql/execution/joins/HashOuterJoin.scala | 3 +- .../sql/execution/joins/HashSemiJoin.scala | 91 +++++++++++++++++++ .../execution/joins/LeftSemiJoinHash.scala | 35 ++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 12 +++ .../sql/execution/joins/SemiJoinSuite.scala | 74 +++++++++++++++ 7 files changed, 208 insertions(+), 59 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 73b463471e..240332a80a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -38,14 +38,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.autoBroadcastJoinThreshold > 0 && right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => - val semiJoin = joins.BroadcastLeftSemiJoinHash( - leftKeys, rightKeys, planLater(left), planLater(right)) - condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil + joins.BroadcastLeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val semiJoin = joins.LeftSemiJoinHash( - leftKeys, rightKeys, planLater(left), planLater(right)) - condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil + joins.LeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index f7b46d6888..2750f58b00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -33,37 +33,27 @@ case class BroadcastLeftSemiJoinHash( leftKeys: Seq[Expression], rightKeys: Seq[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override val buildSide: BuildSide = BuildRight - - override def output: Seq[Attribute] = left.output + right: SparkPlan, + condition: Option[Expression]) extends BinaryNode with HashSemiJoin { protected override def doExecute(): RDD[InternalRow] = { - val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator - val hashSet = new java.util.HashSet[InternalRow]() - var currentRow: InternalRow = null + val buildIter = right.execute().map(_.copy()).collect().toIterator - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - // rowKey may be not serializable (from codegen) - hashSet.add(rowKey.copy()) - } - } - } + if (condition.isEmpty) { + // rowKey may be not serializable (from codegen) + val hashSet = buildKeyHashSet(buildIter, copy = true) + val broadcastedRelation = sparkContext.broadcast(hashSet) - val broadcastedRelation = sparkContext.broadcast(hashSet) + left.execute().mapPartitions { streamIter => + hashSemiJoin(streamIter, broadcastedRelation.value) + } + } else { + val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + val broadcastedRelation = sparkContext.broadcast(hashRelation) - streamedPlan.execute().mapPartitions { streamIter => - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue) - }) + left.execute().mapPartitions { streamIter => + hashSemiJoin(streamIter, broadcastedRelation.value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 0522ee85ee..74a7db7761 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -65,8 +65,7 @@ override def outputPartitioning: Partitioning = joinType match { @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) @transient private[this] lazy val boundCondition = - condition.map( - newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true) + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala new file mode 100644 index 0000000000..1b983bc3a9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan + + +trait HashSemiJoin { + self: SparkPlan => + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val left: SparkPlan + val right: SparkPlan + val condition: Option[Expression] + + override def output: Seq[Attribute] = left.output + + @transient protected lazy val rightKeyGenerator: Projection = + newProjection(rightKeys, right.output) + + @transient protected lazy val leftKeyGenerator: () => MutableProjection = + newMutableProjection(leftKeys, left.output) + + @transient private lazy val boundCondition = + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + + protected def buildKeyHashSet( + buildIter: Iterator[InternalRow], + copy: Boolean): java.util.Set[InternalRow] = { + val hashSet = new java.util.HashSet[InternalRow]() + var currentRow: InternalRow = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = rightKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + if (copy) { + hashSet.add(rowKey.copy()) + } else { + // rowKey may be not serializable (from codegen) + hashSet.add(rowKey) + } + } + } + } + hashSet + } + + protected def hashSemiJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator() + val joinedRow = new JoinedRow + streamIter.filter(current => { + lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue) + !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists { + (build: InternalRow) => boundCondition(joinedRow(current, build)) + } + }) + } + + protected def hashSemiJoin( + streamIter: Iterator[InternalRow], + hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator() + val joinedRow = new JoinedRow + streamIter.filter(current => { + !joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue) + }) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 611ba928a1..9eaac817d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -34,36 +34,21 @@ case class LeftSemiJoinHash( leftKeys: Seq[Expression], rightKeys: Seq[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override val buildSide: BuildSide = BuildRight + right: SparkPlan, + condition: Option[Expression]) extends BinaryNode with HashSemiJoin { override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = left.output - protected override def doExecute(): RDD[InternalRow] = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashSet = new java.util.HashSet[InternalRow]() - var currentRow: InternalRow = null - - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey) - } - } + right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => + if (condition.isEmpty) { + val hashSet = buildKeyHashSet(buildIter, copy = false) + hashSemiJoin(streamIter, hashSet) + } else { + val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + hashSemiJoin(streamIter, hashRelation) } - - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) - }) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5b8b70ed5a..61d5f2061a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -395,6 +395,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { ) } + test("left semi greater than predicate and equal operator") { + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"), + Seq(Row(3, 1), Row(3, 2)) + ) + + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"), + Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)) + ) + } + test("index into array of arrays") { checkAnswer( sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala new file mode 100644 index 0000000000..927e85a7db --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} + + +class SemiJoinSuite extends SparkPlanTest{ + val left = Seq( + (1, 2.0), + (1, 2.0), + (2, 1.0), + (2, 1.0), + (3, 3.0) + ).toDF("a", "b") + + val right = Seq( + (2, 3.0), + (2, 3.0), + (3, 2.0), + (4, 1.0) + ).toDF("c", "d") + + val leftKeys: List[Expression] = 'a :: Nil + val rightKeys: List[Expression] = 'c :: Nil + val condition = Some(LessThan('b, 'd)) + + test("left semi join hash") { + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), + Seq( + (2, 1.0), + (2, 1.0) + ).map(Row.fromTuple)) + } + + test("left semi join BNL") { + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinBNL(left, right, condition), + Seq( + (1, 2.0), + (1, 2.0), + (2, 1.0), + (2, 1.0) + ).map(Row.fromTuple)) + } + + test("broadcast left semi join hash") { + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), + Seq( + (2, 1.0), + (2, 1.0) + ).map(Row.fromTuple)) + } +}