Skip to content

Commit

Permalink
[SPARK-21164][SQL] Remove isTableSample from Sample and isGenerated from Alias and AttributeReference
Browse files Browse the repository at this point in the history

## What changes were proposed in this pull request?
`isTableSample` and `isGenerated` were introduced for SQL generation by apache#11148 and apache#11050, respectively.

Since SQL generation has been removed, neither `isTableSample` nor `isGenerated` is needed any longer, so this PR removes both.

## How was this patch tested?
The existing test cases

Author: Xiao Li <[email protected]>

Closes apache#18379 from gatorsmile/CleanSample.
  • Loading branch information
gatorsmile committed Jun 23, 2017
1 parent 13c2a4f commit 03eb611
Show file tree
Hide file tree
Showing 15 changed files with 40 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ class Analyzer(

def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.map {
case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
case a: Alias => Alias(a.child, a.name)()
case other => other
}
}
Expand Down Expand Up @@ -1368,7 +1368,7 @@ class Analyzer(
val aggregatedCondition =
Aggregate(
grouping,
Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
Alias(havingCondition, "havingCondition")() :: Nil,
child)
val resolvedOperator = execute(aggregatedCondition)
def resolvedAggregateFilter =
Expand Down Expand Up @@ -1424,7 +1424,7 @@ class Analyzer(
try {
val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
val aliasedOrdering =
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true))
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
val resolvedAliasedOrdering: Seq[Alias] =
Expand Down Expand Up @@ -1935,7 +1935,7 @@ class Analyzer(
leafNondeterministic.distinct.map { e =>
val ne = e match {
case n: NamedExpression => n
case _ => Alias(e, "_nondeterministic")(isGenerated = true)
case _ => Alias(e, "_nondeterministic")()
}
e -> ne
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,6 @@ trait NamedExpression extends Expression {
/** Returns the metadata when an expression is a reference to another expression with metadata. */
def metadata: Metadata = Metadata.empty

/** Returns true if the expression is generated by Catalyst */
def isGenerated: java.lang.Boolean = false

/** Returns a copy of this expression with a new `exprId`. */
def newInstance(): NamedExpression

Expand Down Expand Up @@ -128,13 +125,11 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
* qualified way. Consider the examples tableName.name, subQueryAlias.name.
* tableName and subQueryAlias are possible qualifiers.
* @param explicitMetadata Explicit metadata associated with this alias that overwrites child's.
* @param isGenerated A flag to indicate if this alias is generated by Catalyst
*/
case class Alias(child: Expression, name: String)(
val exprId: ExprId = NamedExpression.newExprId,
val qualifier: Option[String] = None,
val explicitMetadata: Option[Metadata] = None,
override val isGenerated: java.lang.Boolean = false)
val explicitMetadata: Option[Metadata] = None)
extends UnaryExpression with NamedExpression {

// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
Expand All @@ -159,13 +154,11 @@ case class Alias(child: Expression, name: String)(
}

def newInstance(): NamedExpression =
Alias(child, name)(
qualifier = qualifier, explicitMetadata = explicitMetadata, isGenerated = isGenerated)
Alias(child, name)(qualifier = qualifier, explicitMetadata = explicitMetadata)

override def toAttribute: Attribute = {
if (resolved) {
AttributeReference(name, child.dataType, child.nullable, metadata)(
exprId, qualifier, isGenerated)
AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifier)
} else {
UnresolvedAttribute(name)
}
Expand All @@ -174,7 +167,7 @@ case class Alias(child: Expression, name: String)(
override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"

override protected final def otherCopyArgs: Seq[AnyRef] = {
exprId :: qualifier :: explicitMetadata :: isGenerated :: Nil
exprId :: qualifier :: explicitMetadata :: Nil
}

override def hashCode(): Int = {
Expand Down Expand Up @@ -207,16 +200,14 @@ case class Alias(child: Expression, name: String)(
* @param qualifier An optional string that can be used to referred to this attribute in a fully
* qualified way. Consider the examples tableName.name, subQueryAlias.name.
* tableName and subQueryAlias are possible qualifiers.
* @param isGenerated A flag to indicate if this reference is generated by Catalyst
*/
case class AttributeReference(
name: String,
dataType: DataType,
nullable: Boolean = true,
override val metadata: Metadata = Metadata.empty)(
val exprId: ExprId = NamedExpression.newExprId,
val qualifier: Option[String] = None,
override val isGenerated: java.lang.Boolean = false)
val qualifier: Option[String] = None)
extends Attribute with Unevaluable {

/**
Expand Down Expand Up @@ -253,8 +244,7 @@ case class AttributeReference(
}

override def newInstance(): AttributeReference =
AttributeReference(name, dataType, nullable, metadata)(
qualifier = qualifier, isGenerated = isGenerated)
AttributeReference(name, dataType, nullable, metadata)(qualifier = qualifier)

/**
* Returns a copy of this [[AttributeReference]] with changed nullability.
Expand All @@ -263,15 +253,15 @@ case class AttributeReference(
if (nullable == newNullability) {
this
} else {
AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier, isGenerated)
AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier)
}
}

override def withName(newName: String): AttributeReference = {
if (name == newName) {
this
} else {
AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier, isGenerated)
AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier)
}
}

Expand All @@ -282,24 +272,24 @@ case class AttributeReference(
if (newQualifier == qualifier) {
this
} else {
AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier, isGenerated)
AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier)
}
}

def withExprId(newExprId: ExprId): AttributeReference = {
if (exprId == newExprId) {
this
} else {
AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier, isGenerated)
AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier)
}
}

override def withMetadata(newMetadata: Metadata): Attribute = {
AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier, isGenerated)
AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
}

override protected final def otherCopyArgs: Seq[AnyRef] = {
exprId :: qualifier :: isGenerated :: Nil
exprId :: qualifier :: Nil
}

/** Used to signal the column used to calculate an eventTime watermark (e.g. a#1-T{delayMs}) */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// Aggregation strategy can handle queries with a single distinct group.
if (distinctAggGroups.size > 1) {
// Create the attributes for the grouping id and the group by clause.
val gid = AttributeReference("gid", IntegerType, nullable = false)(isGenerated = true)
val gid = AttributeReference("gid", IntegerType, nullable = false)()
val groupByMap = a.groupingExpressions.collect {
case ne: NamedExpression => ne -> ne.toAttribute
case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
s"Sampling fraction ($fraction) must be on interval [0, 1]",
ctx)
Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true)
Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)
}

ctx.sampleType.getType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ object PhysicalOperation extends PredicateHelper {
expr.transform {
case a @ Alias(ref: AttributeReference, name) =>
aliases.get(ref)
.map(Alias(_, name)(a.exprId, a.qualifier, isGenerated = a.isGenerated))
.map(Alias(_, name)(a.exprId, a.qualifier))
.getOrElse(a)

case a: AttributeReference =>
aliases.get(a)
.map(Alias(_, a.name)(a.exprId, a.qualifier, isGenerated = a.isGenerated)).getOrElse(a)
.map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
// normalize that for equality testing, by assigning expr id from 0 incrementally. The
// alias name doesn't matter and should be erased.
val normalizedChild = QueryPlan.normalizeExprId(a.child, allAttributes)
Alias(normalizedChild, "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated)
Alias(normalizedChild, "")(ExprId(id), a.qualifier)

case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 =>
// Top level `AttributeReference` may also be used for output like `Alias`, we should
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
nameParts: Seq[String],
resolver: Resolver,
attribute: Attribute): Option[(Attribute, List[String])] = {
if (!attribute.isGenerated && resolver(attribute.name, nameParts.head)) {
if (resolver(attribute.name, nameParts.head)) {
Option((attribute.withName(nameParts.head), nameParts.tail.toList))
} else {
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -807,15 +807,13 @@ case class SubqueryAlias(
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
* @param child the LogicalPlan
* @param isTableSample Is created from TABLESAMPLE in the parser.
*/
case class Sample(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
child: LogicalPlan)(
val isTableSample: java.lang.Boolean = false) extends UnaryNode {
child: LogicalPlan) extends UnaryNode {

val eps = RandomSampler.roundingEpsilon
val fraction = upperBound - lowerBound
Expand All @@ -842,8 +840,6 @@ case class Sample(
// Don't propagate column stats, because we don't know the distribution after a sample operation
Statistics(sizeInBytes, sampledRowCount, hints = childStats.hints)
}

override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val plan5 = Filter(
Exists(
Sample(0.0, 0.5, false, 1L,
Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))().select('b)
Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))).select('b)
),
LocalRelation(a))
assertAnalysisError(plan5,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite {

// Other unary operations
testUnaryOperatorInStreamingPlan(
"sample", Sample(0.1, 1, true, 1L, _)(), expectedMsg = "sampling")
"sample", Sample(0.1, 1, true, 1L, _), expectedMsg = "sampling")
testUnaryOperatorInStreamingPlan(
"window", Window(Nil, Nil, Nil, _), expectedMsg = "non-time-based windows")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,14 @@ class ColumnPruningSuite extends PlanTest {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
val x = testRelation.subquery('x)

val query1 = Sample(0.0, 0.6, false, 11L, x)().select('a)
val query1 = Sample(0.0, 0.6, false, 11L, x).select('a)
val optimized1 = Optimize.execute(query1.analyze)
val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a))()
val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a))
comparePlans(optimized1, expected1.analyze)

val query2 = Sample(0.0, 0.6, false, 11L, x)().select('a as 'aa)
val query2 = Sample(0.0, 0.6, false, 11L, x).select('a as 'aa)
val optimized2 = Optimize.execute(query2.analyze)
val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a))().select('a as 'aa)
val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a)).select('a as 'aa)
comparePlans(optimized2, expected2.analyze)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,9 @@ class PlanParserSuite extends AnalysisTest {
assertEqual(s"$sql tablesample(100 rows)",
table("t").limit(100).select(star()))
assertEqual(s"$sql tablesample(43 percent) as x",
Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
Sample(0, .43d, withReplacement = false, 10L, table("t").as("x")).select(star()))
assertEqual(s"$sql tablesample(bucket 4 out of 10) as x",
Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
Sample(0, .4d, withReplacement = false, 10L, table("t").as("x")).select(star()))
intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x",
"TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported")
intercept(s"$sql tablesample(bucket 11 out of 10) as x",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
*/
protected def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan transform {
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
case Filter(condition: Expression, child: LogicalPlan) =>
Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(_.hashCode())
.reduce(And), child)
case sample: Sample =>
sample.copy(seed = 0L)(true)
case join @ Join(left, right, joinType, condition) if condition.isDefined =>
sample.copy(seed = 0L)
case Join(left, right, joinType, condition) if condition.isDefined =>
val newCondition =
splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode())
.reduce(And)
Join(left, right, joinType, Some(newCondition))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
}

test("sample estimation") {
val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)()
val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)
checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5)))

// Child doesn't have rowCount in stats
val childStats = Statistics(sizeInBytes = 120)
val childPlan = DummyLogicalPlan(childStats, childStats)
val sample2 =
Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)()
Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)
checkStats(sample2, Statistics(sizeInBytes = 14))
}

Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1807,7 +1807,7 @@ class Dataset[T] private[sql](
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
withTypedPlan {
Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
Sample(0.0, fraction, withReplacement, seed, logicalPlan)
}
}

Expand Down Expand Up @@ -1863,7 +1863,7 @@ class Dataset[T] private[sql](
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new Dataset[T](
sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder)
sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan), encoder)
}.toArray
}

Expand Down

0 comments on commit 03eb611

Please sign in to comment.