[SPARK-35544][SQL] Add tree pattern pruning to Analyzer rules
### What changes were proposed in this pull request?

Added the following TreePattern enums (a sketch of how a node advertises its pattern follows this list):
- AGGREGATE_EXPRESSION
- ALIAS
- GROUPING_ANALYTICS
- GENERATOR
- HIGH_ORDER_FUNCTION
- LAMBDA_FUNCTION
- NEW_INSTANCE
- PIVOT
- PYTHON_UDF
- TIME_WINDOW
- TIME_ZONE_AWARE_EXPRESSION
- UP_CAST
- COMMAND
- EVENT_TIME_WATERMARK
- UNRESOLVED_RELATION
- WITH_WINDOW_DEFINITION
- UNRESOLVED_ALIAS
- UNRESOLVED_ATTRIBUTE
- UNRESOLVED_DESERIALIZER
- UNRESOLVED_ORDINAL
- UNRESOLVED_FUNCTION
- UNRESOLVED_HINT
- UNRESOLVED_SUBQUERY_COLUMN_ALIAS
- UNRESOLVED_FUNC
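
Each new enum is advertised by the corresponding node via `nodePatterns`, as the hunks below show for `UnresolvedRelation`, `UnresolvedAttribute`, `PythonUDF`, and others. A minimal sketch of the mechanism, assuming a simplified `UnresolvedRelation` (the real class carries additional fields such as options and an `isStreaming` flag):

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_RELATION}

// Simplified for illustration; the real UnresolvedRelation has more fields.
case class UnresolvedRelation(multipartIdentifier: Seq[String]) extends LeafNode {
  override def output: Seq[Attribute] = Nil
  override lazy val resolved = false

  // TreeNode folds the nodePatterns of a whole subtree into a cached bit set,
  // so containsPattern(UNRESOLVED_RELATION) on any ancestor is a constant-time check.
  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION)
}
```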

Added tree pattern pruning to the following Analyzer rules (the mechanical change to each rule is sketched after this list):
- ResolveBinaryArithmetic
- WindowsSubstitution
- ResolveAliases
- ResolveGroupingAnalytics
- ResolvePivot
- ResolveOrdinalInOrderByAndGroupBy
- LookupFunction
- ResolveSubquery
- ResolveSubqueryColumnAliases
- ApplyCharTypePadding
- UpdateOuterReferences
- ResolveCreateNamedStruct
- TimeWindowing
- CleanupAliases
- EliminateUnions
- EliminateSubqueryAliases
- HandleAnalysisOnlyCommand
- ResolveNewInstances
- ResolveUpCast
- ResolveDeserializer
- ResolveOutputRelation
- ResolveEncodersInUDF
- HandleNullInputsForUDF
- ResolveGenerate
- ExtractGenerator
- GlobalAggregates
- ResolveAggregateFunctions
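
The change to each rule is mechanical: replace the unconditional traversal (`resolveOperatorsUp`, `resolveExpressions`, ...) with its `WithPruning` variant, guarded by the patterns the rule can actually rewrite. A sketch under an illustrative rule name (`MyPrunedRule` is not part of this patch; the traversal and predicate APIs are the ones used in the hunks below):

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{PLAN_EXPRESSION, UNRESOLVED_RELATION}

object MyPrunedRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan =
    // Before: plan.resolveOperatorsUp { ... } visited every operator on every pass.
    // After: any subtree whose cached pattern bits contain neither
    // UNRESOLVED_RELATION nor PLAN_EXPRESSION is skipped without being traversed.
    plan.resolveOperatorsUpWithPruning(
        _.containsAnyPattern(UNRESOLVED_RELATION, PLAN_EXPRESSION)) {
      case p if p.containsPattern(UNRESOLVED_RELATION) => p // actual rewrite goes here
    }
}
```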

### Why are the changes needed?

Reduce the number of tree traversals and hence query compilation latency.

### How was this patch tested?

Existing tests.
Performance diff (Experiment/Baseline below 1.0 means the rule got faster):
Rule | Baseline | Experiment | Experiment/Baseline
-- | -- | -- | --
ResolveBinaryArithmetic | 43264874 | 34707117 | 0.80
WindowsSubstitution | 3322996 | 2734192 | 0.82
ResolveAliases | 24859263 | 21359941 | 0.86
ResolveGroupingAnalytics | 39249143 | 25417569 | 0.80
ResolvePivot | 6393408 | 2843314 | 0.44
ResolveOrdinalInOrderByAndGroupBy | 10750806 | 3386715 | 0.32
LookupFunction | 22087384 | 15481294 | 0.70
ResolveSubquery | 1129139340 | 944402323 | 0.84
ResolveSubqueryColumnAliases | 5055038 | 2808210 | 0.56
ApplyCharTypePadding | 76285576 | 63785681 | 0.84
UpdateOuterReferences | 6548321 | 3092539 | 0.47
ResolveCreateNamedStruct | 38111477 | 17350249 | 0.46
TimeWindowing | 41694190 | 3739134 | 0.09
CleanupAliases | 48683506 | 39584921 | 0.81
EliminateUnions | 3405069 | 2372506 | 0.70
EliminateSubqueryAliases | 9626649 | 9518216 | 0.99
HandleAnalysisOnlyCommand | 2562123 | 2661432 | 1.04
ResolveNewInstances | 16208966 | 1982314 | 0.12
ResolveUpCast | 14067843 | 1868615 | 0.13
ResolveDeserializer | 17991103 | 2320308 | 0.13
ResolveOutputRelation | 5815277 | 2088787 | 0.36
ResolveEncodersInUDF | 14182892 | 1045113 | 0.07
HandleNullInputsForUDF | 19850838 | 1329528 | 0.07
ResolveGenerate | 5587345 | 1953192 | 0.35
ExtractGenerator | 120378046 | 3386286 | 0.03
GlobalAggregates | 16510455 | 13553155 | 0.82
ResolveAggregateFunctions | 1041848509 | 828049280 | 0.79


Closes apache#32686 from sigmod/analyzer.

Authored-by: Yingyi Bu <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
sigmod authored and gengliangwang committed Jun 1, 2021
1 parent 73d4f67 commit 1dd0ca2
Showing 30 changed files with 188 additions and 72 deletions.

Large diffs are not rendered by default.

@@ -22,6 +22,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, With}
import org.apache.spark.sql.catalyst.rules.Rule
+ import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf.{LEGACY_CTE_PRECEDENCE_POLICY, LegacyBehaviorPolicy}

@@ -130,13 +131,13 @@ object CTESubstitution extends Rule[LogicalPlan] {
* @return the plan where CTE substitution is applied
*/
private def traverseAndSubstituteCTE(plan: LogicalPlan): LogicalPlan = {
- plan.resolveOperatorsUp {
+ plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(UNRESOLVED_RELATION, PLAN_EXPRESSION)) {
case With(child: LogicalPlan, relations) =>
val resolvedCTERelations = resolveCTERelations(relations, isLegacy = false)
substituteCTE(child, resolvedCTERelations)

case other =>
- other.transformExpressions {
+ other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
case e: SubqueryExpression => e.withNewPlan(traverseAndSubstituteCTE(e.plan))
}
}
@@ -166,13 +167,13 @@ private def substituteCTE(
private def substituteCTE(
plan: LogicalPlan,
cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan =
- plan resolveOperatorsUp {
+ plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(UNRESOLVED_RELATION, PLAN_EXPRESSION)) {
case u @ UnresolvedRelation(Seq(table), _, _) =>
cteRelations.find(r => plan.conf.resolver(r._1, table)).map(_._2).getOrElse(u)

case other =>
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
- other transformExpressions {
+ other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
case e: SubqueryExpression => e.withNewPlan(substituteCTE(e.plan, cteRelations))
}
}
@@ -22,7 +22,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeSet, NamedExpression, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
- import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+ import org.apache.spark.sql.catalyst.trees.TreePattern._

/**
* A helper class used to detect duplicate relations fast in `DeduplicateRelations`
@@ -41,7 +41,7 @@ case class ReferenceEqualPlanWrapper(plan: LogicalPlan) {
object DeduplicateRelations extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
renewDuplicatedRelations(mutable.HashSet.empty, plan)._1.resolveOperatorsUpWithPruning(
- AlwaysProcess.fn, ruleId) {
+ _.containsAnyPattern(JOIN, INTERSECT, EXCEPT, UNION, COMMAND), ruleId) {
case p: LogicalPlan if !p.childrenResolved => p
// To resolve duplicate expression IDs for Join.
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
@@ -19,14 +19,16 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.plans.logical.{DropTable, DropView, LogicalPlan, NoopCommand, UncacheTable}
import org.apache.spark.sql.catalyst.rules.Rule
+ import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND

/**
* A rule for handling commands when the table or temp view is not resolved.
* These commands support a flag, "ifExists", so that they do not fail when a relation is not
resolved. If the "ifExists" flag is set to true, the plan is resolved to [[NoopCommand]].
*/
object ResolveCommandsWithIfExists extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+ _.containsPattern(COMMAND)) {
case DropTable(u: UnresolvedTableOrView, ifExists, _) if ifExists =>
NoopCommand("DROP TABLE", u.multipartIdentifier)
case DropView(u: UnresolvedView, ifExists) if ifExists =>
@@ -24,7 +24,8 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, IntegerLiteral, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
- import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin}
+ import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+ import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_HINT
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf

@@ -144,7 +145,7 @@ object ResolveHints {
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
- AlwaysProcess.fn, ruleId) {
+ _.containsPattern(UNRESOLVED_HINT), ruleId) {
case h: UnresolvedHint if STRATEGY_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
if (h.parameters.isEmpty) {
// If there is no table alias specified, apply the hint on the entire subtree.
@@ -248,7 +249,7 @@ object ResolveHints {
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
- AlwaysProcess.fn, ruleId) {
+ _.containsPattern(UNRESOLVED_HINT), ruleId) {
case hint @ UnresolvedHint(hintName, _, _) => hintName.toUpperCase(Locale.ROOT) match {
case "REPARTITION" =>
createRepartition(shuffle = true, hint)
@@ -269,7 +270,8 @@

private def hintErrorHandler = conf.hintErrorHandler

- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+ _.containsPattern(UNRESOLVED_HINT)) {
case h: UnresolvedHint =>
hintErrorHandler.hintNotRecognized(h.name, h.parameters)
h.child
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2PartitionCommand}
import org.apache.spark.sql.catalyst.rules.Rule
+ import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement
import org.apache.spark.sql.types._
@@ -32,7 +33,8 @@ import org.apache.spark.sql.util.PartitioningUtils.{normalizePartitionSpec, requ
*/
object ResolvePartitionSpec extends Rule[LogicalPlan] {

- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+ _.containsPattern(COMMAND)) {
case command: V2PartitionCommand if command.childrenResolved && !command.resolved =>
command.table match {
case r @ ResolvedTable(_, _, table: SupportsPartitionManagement, _) =>
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{CombineUnions, OptimizeUpdateFields}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
- import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+ import org.apache.spark.sql.catalyst.trees.TreePattern.UNION
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
@@ -250,7 +250,7 @@ object ResolveUnion extends Rule[LogicalPlan] {
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
- AlwaysProcess.fn, ruleId) {
+ _.containsPattern(UNION), ruleId) {
case e if !e.childrenResolved => e

case Union(children, byName, allowMissingCol) if byName =>
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{BaseGroupingSets, Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
- import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
+ import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.types.IntegerType

/**
@@ -41,7 +41,7 @@ object SubstituteUnresolvedOrdinals extends Rule[LogicalPlan] {
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
- AlwaysProcess.fn, ruleId) {
+ t => t.containsPattern(LITERAL) && t.containsAnyPattern(AGGREGATE, SORT), ruleId) {
case s: Sort if conf.orderByOrdinal && s.order.exists(o => containIntLiteral(o.child)) =>
val newOrders = s.order.map {
case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
+ import org.apache.spark.sql.catalyst.trees.AlwaysProcess

/**
* Updates nullability of Attributes in a resolved LogicalPlan by using the nullability of
@@ -32,7 +33,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
*/
object UpdateAttributeNullability extends Rule[LogicalPlan] {

- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+ AlwaysProcess.fn, ruleId) {
// Skip unresolved nodes.
case p if !p.resolved => p
// Skip leaf node, as it has no child and no need to update nullability.
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
- import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+ import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.connector.catalog.{CatalogManager, LookupCatalog}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.DataType
@@ -34,7 +34,8 @@ import org.apache.spark.sql.types.DataType
case class ResolveHigherOrderFunctions(catalogManager: CatalogManager)
extends Rule[LogicalPlan] with LookupCatalog {

- override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning(
+ _.containsPattern(LAMBDA_FUNCTION), ruleId) {
case u @ UnresolvedFunction(AsFunctionIdentifier(ident), children, false, filter, ignoreNulls)
if hasLambdaAndResolvedArguments(children) =>
withPosition(u) {
@@ -91,7 +92,8 @@ object ResolveLambdaVariables extends Rule[LogicalPlan] {
}

override def apply(plan: LogicalPlan): LogicalPlan = {
- plan.resolveOperatorsWithPruning(AlwaysProcess.fn, ruleId) {
+ plan.resolveOperatorsWithPruning(
+ _.containsAnyPattern(HIGH_ORDER_FUNCTION, LAMBDA_FUNCTION, LAMBDA_VARIABLE), ruleId) {
case q: LogicalPlan =>
q.mapExpressions(resolve(_, Map.empty))
}
@@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
- import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+ import org.apache.spark.sql.catalyst.trees.TreePattern.{LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION}
import org.apache.spark.sql.types.DataType

/**
@@ -39,7 +39,9 @@ object ResolveTimeZone extends Rule[LogicalPlan] {
}

override def apply(plan: LogicalPlan): LogicalPlan =
- plan.resolveExpressionsWithPruning(AlwaysProcess.fn, ruleId)(transformTimeZoneExprs)
+ plan.resolveExpressionsWithPruning(
+ _.containsAnyPattern(LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION), ruleId
+ )(transformTimeZoneExprs)

def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs)
}
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
+ import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -57,6 +58,8 @@ case class UnresolvedRelation(
override def output: Seq[Attribute] = Nil

override lazy val resolved = false

+ final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION)
}

object UnresolvedRelation {
@@ -165,6 +168,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
override def withMetadata(newMetadata: Metadata): Attribute = this
override def withExprId(newExprId: ExprId): UnresolvedAttribute = this
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ATTRIBUTE)

override def toString: String = s"'$name"

@@ -282,6 +286,7 @@ case class UnresolvedFunction(
override def dataType: DataType = throw new UnresolvedException("dataType")
override def nullable: Boolean = throw new UnresolvedException("nullable")
override lazy val resolved = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_FUNCTION)

override def prettyName: String = nameParts.quoted
override def toString: String = {
@@ -462,6 +467,8 @@ case class MultiAlias(child: Expression, names: Seq[String])

override def newInstance(): NamedExpression = throw new UnresolvedException("newInstance")

+ final override val nodePatterns: Seq[TreePattern] = Seq(MULTI_ALIAS)

override lazy val resolved = false

override def toString: String = s"$child AS $names"
@@ -529,6 +536,7 @@ case class UnresolvedAlias(
override def dataType: DataType = throw new UnresolvedException("dataType")
override def name: String = throw new UnresolvedException("name")
override def newInstance(): NamedExpression = throw new UnresolvedException("newInstance")
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ALIAS)

override lazy val resolved = false

@@ -556,6 +564,8 @@ case class UnresolvedSubqueryColumnAliases(

override lazy val resolved = false

+ final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_SUBQUERY_COLUMN_ALIAS)

override protected def withNewChildInternal(
newChild: LogicalPlan): UnresolvedSubqueryColumnAliases = copy(child = newChild)
}
@@ -579,6 +589,7 @@ case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq
override def dataType: DataType = throw new UnresolvedException("dataType")
override def nullable: Boolean = throw new UnresolvedException("nullable")
override lazy val resolved = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_DESERIALIZER)

override protected def withNewChildInternal(newChild: Expression): UnresolvedDeserializer =
copy(deserializer = newChild)
@@ -616,6 +627,7 @@ case class UnresolvedOrdinal(ordinal: Int)
override def dataType: DataType = throw new UnresolvedException("dataType")
override def nullable: Boolean = throw new UnresolvedException("nullable")
override lazy val resolved = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ORDINAL)
}

/**
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, Table, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -95,6 +96,7 @@ case class UnresolvedPartitionSpec(
case class UnresolvedFunc(multipartIdentifier: Seq[String]) extends LeafNode {
override lazy val resolved: Boolean = false
override def output: Seq[Attribute] = Nil
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_FUNC)
}

/**
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvable
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
- import org.apache.spark.sql.catalyst.trees.TreePattern.{CAST, TreePattern}
+ import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -1826,7 +1826,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

- final override val nodePatterns: Seq[TreePattern] = Seq(CAST)
+ final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)

override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled

@@ -2030,6 +2030,8 @@ case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: S
extends UnaryExpression with Unevaluable {
override lazy val resolved = false

+ final override val nodePatterns: Seq[TreePattern] = Seq(UP_CAST)

def dataType: DataType = target match {
case DecimalType => DecimalType.SYSTEM_DEFAULT
case _ => target.asInstanceOf[DataType]
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
+ import org.apache.spark.sql.catalyst.trees.TreePattern.{PYTHON_UDF, TreePattern}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.types.DataType

@@ -63,6 +64,8 @@ case class PythonUDF(

override def toString: String = s"$name(${children.mkString(", ")})"

+ final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)

lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)(
exprId = resultId)

@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+ import org.apache.spark.sql.catalyst.trees.TreePattern.{SCALA_UDF, TreePattern}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
import org.apache.spark.util.Utils
@@ -57,6 +58,8 @@ case class ScalaUDF(

override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)

+ final override val nodePatterns: Seq[TreePattern] = Seq(SCALA_UDF)

override def toString: String = s"$name(${children.mkString(", ")})"

override def name: String = udfName.getOrElse("UDF")