[SPARK-24500][SQL] Make sure streams are materialized during Tree transforms.

## What changes were proposed in this pull request?
If you construct Catalyst trees using `scala.collection.immutable.Stream`, you can run into situations where valid transformations do not seem to have any effect. There are two causes for this behavior:
- `Stream` is evaluated lazily. Note that the default implementation will generally only evaluate the mapping function for the first element (this makes testing a bit tricky).
- `TreeNode` and `QueryPlan` use side effects to detect whether a tree has changed. Mapping over a stream is lazy and therefore may never trigger this side effect. If that happens, the node will incorrectly assume that it did not change and return itself instead of the newly created node (this reuse is done for GC reasons).

This PR fixes this issue by forcing materialization on streams in `TreeNode` and `QueryPlan`.
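For illustration, here is a minimal, hypothetical snippet (not part of this patch) showing the first point under Scala 2.12: `Stream#map` evaluates the mapping function eagerly only for the head, and `.force` is what materializes the rest.

```scala
// Hypothetical illustration only; not part of the patch.
object StreamLazinessDemo extends App {
  var invocations = 0

  // Stream#map builds the new head eagerly but leaves the tail unevaluated,
  // so the side effect below runs only for the first element at this point.
  val mapped = Stream(1, 2, 3).map { i => invocations += 1; i + 1 }
  assert(invocations == 1)

  // .force walks the whole stream, so the side effect has now run for every element.
  mapped.force
  assert(invocations == 3)
}
```

This is exactly the failure mode for the change-detection flag in `TreeNode`/`QueryPlan`: if the flag is only flipped while mapping, a lazily mapped `Stream` may never flip it.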

## How was this patch tested?
Unit tests were added to `TreeNodeSuite` and `LogicalPlanSuite`. An integration test was added to `PlannerSuite`.

Author: Herman van Hovell <[email protected]>

Closes apache#21539 from hvanhovell/SPARK-24500.
hvanhovell authored and cloud-fan committed Jun 13, 2018
1 parent 7703b46 commit 299d297
Showing 5 changed files with 109 additions and 70 deletions.
@@ -119,6 +119,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
case Some(value) => Some(recursiveTransform(value))
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
case stream: Stream[_] => stream.map(recursiveTransform).force
case seq: Traversable[_] => seq.map(recursiveTransform)
case other: AnyRef => other
case null => null
@@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
var changed = false
val remainingNewChildren = newChildren.toBuffer
val remainingOldChildren = children.toBuffer
def mapTreeNode(node: TreeNode[_]): TreeNode[_] = {
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
if (newChild fastEquals oldChild) {
oldChild
} else {
changed = true
newChild
}
}
def mapChild(child: Any): Any = child match {
case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
case nonChild: AnyRef => nonChild
case null => null
}
val newArgs = mapProductIterator {
case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
// Handle Seq[TreeNode] in TreeNode parameters.
case s: Seq[_] => s.map {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
if (newChild fastEquals oldChild) {
oldChild
} else {
changed = true
newChild
}
case nonChild: AnyRef => nonChild
case null => null
}
case m: Map[_, _] => m.mapValues {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
if (newChild fastEquals oldChild) {
oldChild
} else {
changed = true
newChild
}
case nonChild: AnyRef => nonChild
case null => null
}.view.force // `mapValues` is lazy and we need to force it to materialize
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
if (newChild fastEquals oldChild) {
oldChild
} else {
changed = true
newChild
}
case s: Stream[_] =>
// Stream is lazy so we need to force materialization
s.map(mapChild).force
case s: Seq[_] =>
s.map(mapChild)
case m: Map[_, _] =>
// `mapValues` is lazy and we need to force it to materialize
m.mapValues(mapChild).view.force
case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
case nonChild: AnyRef => nonChild
case null => null
}
@@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
def mapChildren(f: BaseType => BaseType): BaseType = {
if (children.nonEmpty) {
var changed = false
def mapChild(child: Any): Any = child match {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = if (containsChild(arg1)) {
f(arg1.asInstanceOf[BaseType])
} else {
arg1.asInstanceOf[BaseType]
}

val newChild2 = if (containsChild(arg2)) {
f(arg2.asInstanceOf[BaseType])
} else {
arg2.asInstanceOf[BaseType]
}

if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
} else {
tuple
}
case other => other
}

val newArgs = mapProductIterator {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
@@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case other => other
}.view.force // `mapValues` is lazy and we need to force it to materialize
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = if (containsChild(arg1)) {
f(arg1.asInstanceOf[BaseType])
} else {
arg1.asInstanceOf[BaseType]
}

val newChild2 = if (containsChild(arg2)) {
f(arg2.asInstanceOf[BaseType])
} else {
arg2.asInstanceOf[BaseType]
}

if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
} else {
tuple
}
case other => other
}
case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
case args: Traversable[_] => args.map(mapChild)
case nonChild: AnyRef => nonChild
case null => null
}
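A side note on the `.view.force` calls in the diff above: `Map#mapValues` in Scala 2.12 likewise returns a lazy view, which is why it has to be forced before the change flag is inspected. A small, hypothetical sketch of that behavior (not part of the patch):

```scala
// Hypothetical illustration only; not part of the patch.
object MapValuesLazinessDemo extends App {
  var calls = 0

  // mapValues returns a lazy view over the original map; the mapping function
  // has not been applied to any entry yet.
  val lazyMapped = Map("a" -> 1, "b" -> 2).mapValues { v => calls += 1; v + 1 }
  assert(calls == 0)

  // .view.force materializes a strict collection, applying the function once per
  // entry. Reading the lazy view repeatedly would re-run the function instead.
  lazyMapped.view.force
  assert(calls == 2)
}
```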
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

@@ -101,4 +101,22 @@ class LogicalPlanSuite extends SparkFunSuite {
assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
}

test("transformExpressions works with a Stream") {
val id1 = NamedExpression.newExprId
val id2 = NamedExpression.newExprId
val plan = Project(Stream(
Alias(Literal(1), "a")(exprId = id1),
Alias(Literal(2), "b")(exprId = id2)),
OneRowRelation())
val result = plan.transformExpressions {
case Literal(v: Int, IntegerType) if v != 1 =>
Literal(v + 1, IntegerType)
}
val expected = Project(Stream(
Alias(Literal(1), "a")(exprId = id1),
Alias(Literal(3), "b")(exprId = id2)),
OneRowRelation())
assert(result.sameResult(expected))
}
}
@@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.dsl.expressions.DslString
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union}
import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel

case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback {
@@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite {
val right = JsonMethods.parse(rightJson)
assert(left == right)
}

test("transform works on stream of children") {
val before = Coalesce(Stream(Literal(1), Literal(2)))
// Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
// situation in which the TreeNode.mapChildren function's change detection is not triggered. A
// stream's first element is typically materialized, so in order to not trip the TreeNode change
// detection logic, we should not change the first element in the sequence.
val result = before.transform {
case Literal(v: Int, IntegerType) if v != 1 =>
Literal(v + 1, IntegerType)
}
val expected = Coalesce(Stream(Literal(1), Literal(3)))
assert(result === expected)
}

test("withNewChildren on stream of children") {
val before = Coalesce(Stream(Literal(1), Literal(2)))
val result = before.withNewChildren(Stream(Literal(1), Literal(3)))
val expected = Coalesce(Stream(Literal(1), Literal(3)))
assert(result === expected)
}
}
@@ -21,8 +21,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
@@ -679,6 +679,13 @@ class PlannerSuite extends SharedSQLContext {
}
assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0))
}

test("SPARK-24500: create union with stream of children") {
val df = Union(Stream(
Range(1, 1, 1, 1),
Range(1, 2, 1, 1)))
df.queryExecution.executedPlan.execute()
}
}

// Used for unit-testing EnsureRequirements
