Skip to content

Commit

Permalink
[SPARK-5454] More robust handling of self joins
Browse files Browse the repository at this point in the history
Also I fix a bunch of bad output in test cases.

Author: Michael Armbrust <[email protected]>

Closes apache#4520 from marmbrus/selfJoin and squashes the following commits:

4f4a85c [Michael Armbrust] comments
49c8e26 [Michael Armbrust] fix tests
6fc38de [Michael Armbrust] fix style
55d64b3 [Michael Armbrust] fix dataframe selfjoins
  • Loading branch information
marmbrus committed Feb 11, 2015
1 parent 03bf704 commit a60d2b7
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,11 @@ class Analyzer(catalog: Catalog,
val extendedRules: Seq[Rule[LogicalPlan]] = Nil

lazy val batches: Seq[Batch] = Seq(
Batch("MultiInstanceRelations", Once,
NewRelationInstances),
Batch("Resolution", fixedPoint,
ResolveReferences ::
ResolveRelations ::
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolveSortReferences ::
NewRelationInstances ::
ImplicitGenerate ::
ResolveFunctions ::
GlobalAggregates ::
Expand Down Expand Up @@ -285,6 +282,27 @@ class Analyzer(catalog: Catalog,
}
)

// Special handling for cases when a self join introduces duplicate expression ids:
// both sides of the join would otherwise produce attributes with identical ids,
// making references to them ambiguous.
case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
// Attributes produced by both sides of the join.
val conflictingAttributes = left.outputSet.intersect(right.outputSet)

// Find the first relation on the right side whose output clashes, and build a
// fresh copy of it together with a mapping from old to new attributes.
val (oldRelation, newRelation, attributeRewrites) = right.collect {
case oldVersion: MultiInstanceRelation
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.newInstance()
val newAttributes = AttributeMap(oldVersion.output.zip(newVersion.output))
(oldVersion, newVersion, newAttributes)
}.head // Only handle first case found, others will be fixed on the next pass.

// Swap in the fresh relation and rewrite all references to its old attributes.
val newRight = right transformUp {
case r if r == oldRelation => newRelation
case other => other transformExpressions {
case a: Attribute => attributeRewrites.get(a).getOrElse(a)
}
}

// The left side is left untouched; only the right side receives new ids.
j.copy(right = newRight)

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
* produced by distinct operators in a query tree as this breaks the guarantee that expression
* ids, which are used to differentiate attributes, are unique.
*
* During analysis, operators that include this trait may be asked to produce a new version
* of itself with globally unique expression ids.
*/
trait MultiInstanceRelation {
// Produces a copy of this relation in which all output attributes carry fresh,
// globally unique expression ids. `this.type` preserves the concrete relation type.
def newInstance(): this.type
}

/**
* If any MultiInstanceRelation appears more than once in the query plan then the plan is updated so
* that each instance has unique expression ids for the attributes produced.
*/
object NewRelationInstances extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = {
    // Relations that occur more than once in the plan; every occurrence shares
    // the same expression ids until it is replaced below.
    val duplicated = plan
      .collect { case relation: MultiInstanceRelation => relation }
      .groupBy(identity[MultiInstanceRelation])
      .collect { case (relation, occurrences) if occurrences.size > 1 => relation }
      .toSet

    // Replace each duplicated relation with a fresh instance carrying new ids.
    plan transform {
      case relation: MultiInstanceRelation if duplicated.contains(relation) =>
        relation.newInstance()
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@ class PlanTest extends FunSuite {
* we must normalize them to check if two different queries are identical.
*/
protected def normalizeExprIds(plan: LogicalPlan) = {
  plan transformAllExpressions {
    case a: AttributeReference =>
      // Rewrite every attribute to a fixed ExprId(0): only name, data type and
      // nullability matter when comparing two plans for structural equality,
      // so the globally unique ids assigned during analysis are discarded.
      AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
  }
}

Expand Down
2 changes: 2 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
case _ =>
}

@transient
protected[sql] val cacheManager = new CacheManager(this)

/**
Expand Down Expand Up @@ -159,6 +160,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* DataTypes.StringType);
* }}}
*/
@transient
val udf: UDFRegistration = new UDFRegistration(this)

/** Returns true if the table is currently cached in-memory. */
Expand Down
3 changes: 3 additions & 0 deletions sql/core/src/test/resources/log4j.properties
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ log4j.appender.FA.Threshold = INFO
log4j.additivity.parquet.hadoop.ParquetRecordReader=false
log4j.logger.parquet.hadoop.ParquetRecordReader=OFF

log4j.additivity.parquet.hadoop.ParquetOutputCommitter=false
log4j.logger.parquet.hadoop.ParquetOutputCommitter=OFF

log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false
log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF

Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test.TestSQLContext.sql


class DataFrameSuite extends QueryTest {
Expand Down Expand Up @@ -88,6 +89,15 @@ class DataFrameSuite extends QueryTest {
testData.collect().toSeq)
}

test("self join") {
  // Two aliased projections over the same underlying relation; joining them
  // exercises the analyzer's handling of duplicate expression ids.
  val lhs = testData.select(testData("key")).as('df1)
  val rhs = testData.select(testData("key")).as('df2)

  val expected =
    sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq

  checkAnswer(lhs.join(rhs, $"df1.key" === $"df2.key"), expected)
}

test("selectExpr") {
checkAnswer(
testData.selectExpr("abs(key)", "value"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,9 @@ class PlanTest extends FunSuite {
* we must normalize them to check if two different queries are identical.
*/
protected def normalizeExprIds(plan: LogicalPlan) = {
  plan transformAllExpressions {
    case a: AttributeReference =>
      // Rewrite every attribute to a fixed ExprId(0): only name, data type and
      // nullability matter when comparing two plans for structural equality,
      // so the globally unique ids assigned during analysis are discarded.
      AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
  }
}

Expand Down

0 comments on commit a60d2b7

Please sign in to comment.