Skip to content

Commit

Permalink
[SPARK-43529][SQL][FOLLOWUP] Code cleanup in UnresolvedTableSpec and …
Browse files Browse the repository at this point in the history
…related plans

### What changes were proposed in this pull request?

Follow-up of apache#41191 to clean up the code in UnresolvedTableSpec and related plans:
* Rename `OptionsListExpressions` as `OptionList`
* Rename `trait TableSpec` as `TableSpecBase`
* Rename `ResolvedTableSpec` as `TableSpec`, and make sure all the physical plans use `TableSpec` instead of `TableSpecBase`.
* Move option list expressions to UnresolvedTableSpec, so that all the specs are in one class.
* Make `UnresolvedTableSpec` a `UnaryExpression`, so that transforming with `mapExpressions` will transform it and the option list expressions in its child
* Restore the signatures of class `CreateTable`, `CreateTableAsSelect`, `ReplaceTable` and `ReplaceTableAsSelect`

### Why are the changes needed?

Make the code implementation simpler.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests

Closes apache#41549 from gengliangwang/optionsFollowUp.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
  • Loading branch information
gengliangwang committed Jun 13, 2023
1 parent dfd40a4 commit 7e94f2a
Show file tree
Hide file tree
Showing 14 changed files with 148 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,23 @@ object ResolveTableSpec extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsWithPruning(_.containsAnyPattern(COMMAND), ruleId) {
case t: CreateTable =>
resolveTableSpec(t, t.tableSpec, t.optionsListExpressions, s => t.copy(tableSpec = s))
resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s))
case t: CreateTableAsSelect =>
resolveTableSpec(t, t.tableSpec, t.optionsListExpressions, s => t.copy(tableSpec = s))
resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s))
case t: ReplaceTable =>
resolveTableSpec(t, t.tableSpec, t.optionsListExpressions, s => t.copy(tableSpec = s))
resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s))
case t: ReplaceTableAsSelect =>
resolveTableSpec(t, t.tableSpec, t.optionsListExpressions, s => t.copy(tableSpec = s))
resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s))
}
}

/** Helper method to resolve the table specification within a logical plan. */
private def resolveTableSpec(
input: LogicalPlan, tableSpec: TableSpec, optionsListExpressions: OptionsListExpressions,
withNewSpec: TableSpec => LogicalPlan): LogicalPlan = tableSpec match {
case u: UnresolvedTableSpec if optionsListExpressions.allOptionsResolved =>
val newOptions: Seq[(String, String)] = optionsListExpressions.options.map {
input: LogicalPlan,
tableSpec: TableSpecBase,
withNewSpec: TableSpecBase => LogicalPlan): LogicalPlan = tableSpec match {
case u: UnresolvedTableSpec if u.optionExpression.resolved =>
val newOptions: Seq[(String, String)] = u.optionExpression.options.map {
case (key: String, null) =>
(key, null)
case (key: String, value: Expression) =>
Expand All @@ -75,7 +76,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] {
}
(key, newValue)
}
val newTableSpec = ResolvedTableSpec(
val newTableSpec = TableSpec(
properties = u.properties,
provider = u.provider,
options = newOptions.toMap,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3348,13 +3348,13 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
* specified.
*/
override def visitExpressionPropertyList(
ctx: ExpressionPropertyListContext): OptionsListExpressions = {
ctx: ExpressionPropertyListContext): OptionList = {
val options = ctx.expressionProperty.asScala.map { property =>
val key: String = visitPropertyKey(property.key)
val value: Expression = Option(property.value).map(expression).getOrElse(null)
key -> value
}.toSeq
OptionsListExpressions(options)
OptionList(options)
}

override def visitStringLit(ctx: StringLitContext): Token = {
Expand Down Expand Up @@ -3391,7 +3391,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
*/
type TableClauses = (
Seq[Transform], Seq[StructField], Option[BucketSpec], Map[String, String],
OptionsListExpressions, Option[String], Option[String], Option[SerdeInfo])
OptionList, Option[String], Option[String], Option[SerdeInfo])

/**
* Validate a create table statement and return the [[TableIdentifier]].
Expand Down Expand Up @@ -3686,8 +3686,8 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit

def cleanTableOptions(
ctx: ParserRuleContext,
options: OptionsListExpressions,
location: Option[String]): (OptionsListExpressions, Option[String]) = {
options: OptionList,
location: Option[String]): (OptionList, Option[String]) = {
var path = location
val filtered = cleanTableProperties(ctx, options.options.toMap).filter {
case (key, value) if key.equalsIgnoreCase("path") =>
Expand All @@ -3705,7 +3705,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
false
case _ => true
}
(OptionsListExpressions(filtered.toSeq), path)
(OptionList(filtered.toSeq), path)
}

/**
Expand Down Expand Up @@ -3864,7 +3864,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty)
val cleanedProperties = cleanTableProperties(ctx, properties)
val options = Option(ctx.options).map(visitExpressionPropertyList)
.getOrElse(OptionsListExpressions(Seq.empty))
.getOrElse(OptionList(Seq.empty))
val location = visitLocationSpecList(ctx.locationSpec())
val (cleanedOptions, newLocation) = cleanTableOptions(ctx, options, location)
val comment = visitCommentSpecList(ctx.commentSpec())
Expand Down Expand Up @@ -3959,7 +3959,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit

val partitioning =
partitionExpressions(partTransforms, partCols, ctx) ++ bucketSpec.map(_.asTransform)
val tableSpec = UnresolvedTableSpec(properties, provider, location, comment,
val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment,
serdeInfo, external)

Option(ctx.query).map(plan) match {
Expand All @@ -3976,15 +3976,14 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit

case Some(query) =>
CreateTableAsSelect(withIdentClause(identifierContext, UnresolvedIdentifier(_)),
partitioning, query, tableSpec, Map.empty, ifNotExists, optionsListExpressions = options)
partitioning, query, tableSpec, Map.empty, ifNotExists)

case _ =>
// Note: table schema includes both the table columns list and the partition columns
// with data type.
val schema = StructType(columns ++ partCols)
CreateTable(withIdentClause(identifierContext, UnresolvedIdentifier(_)),
schema, partitioning, tableSpec, ignoreIfExists = ifNotExists,
optionsListExpressions = options)
schema, partitioning, tableSpec, ignoreIfExists = ifNotExists)
}
}

Expand Down Expand Up @@ -4029,7 +4028,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit

val partitioning =
partitionExpressions(partTransforms, partCols, ctx) ++ bucketSpec.map(_.asTransform)
val tableSpec = UnresolvedTableSpec(properties, provider, location, comment,
val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment,
serdeInfo, external = false)

Option(ctx.query).map(plan) match {
Expand All @@ -4047,16 +4046,15 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
case Some(query) =>
ReplaceTableAsSelect(
withIdentClause(ctx.replaceTableHeader.identifierReference(), UnresolvedIdentifier(_)),
partitioning, query, tableSpec, writeOptions = Map.empty, orCreate = orCreate,
optionsListExpressions = options)
partitioning, query, tableSpec, writeOptions = Map.empty, orCreate = orCreate)

case _ =>
// Note: table schema includes both the table columns list and the partition columns
// with data type.
val schema = StructType(columns ++ partCols)
ReplaceTable(
withIdentClause(ctx.replaceTableHeader.identifierReference(), UnresolvedIdentifier(_)),
schema, partitioning, tableSpec, orCreate = orCreate, optionsListExpressions = options)
schema, partitioning, tableSpec, orCreate = orCreate)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, UnresolvedException}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.FunctionResource
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, NamedExpression, Unevaluable, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, NamedExpression, UnaryExpression, Unevaluable, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, RowDeltaUtils, WriteDeltaProjections}
Expand All @@ -31,6 +31,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType}
import org.apache.spark.util.Utils

// For v2 DML commands, it may end up with the v1 fallback code path and need to build a DataFrame
// which is required by the DS v1 API. We need to keep the analyzed input query plan to build
Expand Down Expand Up @@ -444,9 +445,8 @@ case class CreateTable(
name: LogicalPlan,
tableSchema: StructType,
partitioning: Seq[Transform],
tableSpec: TableSpec,
ignoreIfExists: Boolean,
optionsListExpressions: OptionsListExpressions = OptionsListExpressions(Seq.empty))
tableSpec: TableSpecBase,
ignoreIfExists: Boolean)
extends UnaryCommand with V2CreateTablePlan {

override def child: LogicalPlan = name
Expand All @@ -466,11 +466,10 @@ case class CreateTableAsSelect(
name: LogicalPlan,
partitioning: Seq[Transform],
query: LogicalPlan,
tableSpec: TableSpec,
tableSpec: TableSpecBase,
writeOptions: Map[String, String],
ignoreIfExists: Boolean,
isAnalyzed: Boolean = false,
optionsListExpressions: OptionsListExpressions = OptionsListExpressions(Seq.empty))
isAnalyzed: Boolean = false)
extends V2CreateTableAsSelectPlan {

override def markAsAnalyzed(ac: AnalysisContext): LogicalPlan = copy(isAnalyzed = true)
Expand Down Expand Up @@ -498,9 +497,8 @@ case class ReplaceTable(
name: LogicalPlan,
tableSchema: StructType,
partitioning: Seq[Transform],
tableSpec: TableSpec,
orCreate: Boolean,
optionsListExpressions: OptionsListExpressions = OptionsListExpressions(Seq.empty))
tableSpec: TableSpecBase,
orCreate: Boolean)
extends UnaryCommand with V2CreateTablePlan {

override def child: LogicalPlan = name
Expand All @@ -523,11 +521,10 @@ case class ReplaceTableAsSelect(
name: LogicalPlan,
partitioning: Seq[Transform],
query: LogicalPlan,
tableSpec: TableSpec,
tableSpec: TableSpecBase,
writeOptions: Map[String, String],
orCreate: Boolean,
isAnalyzed: Boolean = false,
optionsListExpressions: OptionsListExpressions = OptionsListExpressions(Seq.empty))
isAnalyzed: Boolean = false)
extends V2CreateTableAsSelectPlan {

override def markAsAnalyzed(ac: AnalysisContext): LogicalPlan = copy(isAnalyzed = true)
Expand Down Expand Up @@ -1388,25 +1385,34 @@ case class DropIndex(
copy(table = newChild)
}

trait TableSpec {
trait TableSpecBase {
def properties: Map[String, String]
def provider: Option[String]
def location: Option[String]
def comment: Option[String]
def serde: Option[SerdeInfo]
def external: Boolean
def withNewLocation(newLocation: Option[String]): TableSpec
}

case class UnresolvedTableSpec(
properties: Map[String, String],
provider: Option[String],
optionExpression: OptionList,
location: Option[String],
comment: Option[String],
serde: Option[SerdeInfo],
external: Boolean) extends TableSpec {
override def withNewLocation(loc: Option[String]): TableSpec = {
UnresolvedTableSpec(properties, provider, loc, comment, serde, external)
external: Boolean) extends UnaryExpression with Unevaluable with TableSpecBase {

override def dataType: DataType =
throw new UnsupportedOperationException("UnresolvedTableSpec doesn't have a data type")

override def child: Expression = optionExpression

override protected def withNewChildInternal(newChild: Expression): Expression =
this.copy(optionExpression = newChild.asInstanceOf[OptionList])

override def simpleString(maxFields: Int): String = {
this.copy(properties = Utils.redact(properties).toMap).toString
}
}

Expand All @@ -1415,11 +1421,12 @@ case class UnresolvedTableSpec(
* UnresolvedTableSpec lives. We use a separate object so that tree traversals in analyzer rules can
* descend into the child expressions naturally without extra treatment.
*/
case class OptionsListExpressions(options: Seq[(String, Expression)])
case class OptionList(options: Seq[(String, Expression)])
extends Expression with Unevaluable {
override def nullable: Boolean = true
override def dataType: DataType = MapType(StringType, StringType)
override def children: Seq[Expression] = options.map(_._2)
override lazy val resolved: Boolean = options.map(_._2).forall(_.resolved)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = {
Expand All @@ -1428,21 +1435,19 @@ case class OptionsListExpressions(options: Seq[(String, Expression)])
case ((key: String, _), newChild: Expression) =>
(key, newChild)
}
OptionsListExpressions(newOptions)
OptionList(newOptions)
}

lazy val allOptionsResolved: Boolean = options.map(_._2).forall(_.resolved)
}

case class ResolvedTableSpec(
case class TableSpec(
properties: Map[String, String],
provider: Option[String],
options: Map[String, String],
location: Option[String],
comment: Option[String],
serde: Option[SerdeInfo],
external: Boolean) extends TableSpec {
override def withNewLocation(newLocation: Option[String]): TableSpec = {
ResolvedTableSpec(properties, provider, options, newLocation, comment, serde, external)
external: Boolean) extends TableSpecBase {
def withNewLocation(newLocation: Option[String]): TableSpec = {
TableSpec(properties, provider, options, newLocation, comment, serde, external)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.{ResolvedTableSpec, UnresolvedTableSpec}
import org.apache.spark.sql.catalyst.plans.logical.TableSpec
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
import org.apache.spark.sql.catalyst.rules.RuleId
import org.apache.spark.sql.catalyst.rules.RuleIdCollection
Expand Down Expand Up @@ -927,11 +927,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
redactMapString(map.asCaseSensitiveMap().asScala, maxFields)
case map: Map[_, _] =>
redactMapString(map, maxFields)
case t: ResolvedTableSpec =>
case t: TableSpec =>
t.copy(properties = Utils.redact(t.properties).toMap,
options = Utils.redact(t.options).toMap) :: Nil
case t: UnresolvedTableSpec =>
t.copy(properties = Utils.redact(t.properties).toMap) :: Nil
case table: CatalogTable =>
stringArgsForCatalogTable(table)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchTableException, TimeTravelSpec}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{ResolvedTableSpec, SerdeInfo, TableSpec}
import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec}
import org.apache.spark.sql.catalyst.util.GeneratedColumn
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.connector.catalog.TableChange._
Expand Down Expand Up @@ -376,7 +376,7 @@ private[sql] object CatalogV2Util {

def convertTableProperties(t: TableSpec): Map[String, String] = {
val props = convertTableProperties(
t.properties, t.asInstanceOf[ResolvedTableSpec].options, t.serde, t.location, t.comment,
t.properties, t.options, t.serde, t.location, t.comment,
t.provider, t.external)
withDefaultOwnership(props)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ package org.apache.spark.sql.catalyst.analysis
import java.util

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode, UnresolvedTableSpec}
import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode, OptionList, UnresolvedTableSpec}
import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, Table, TableCapability, TableCatalog}
import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class CreateTablePartitioningValidationSuite extends AnalysisTest {
val tableSpec = UnresolvedTableSpec(Map.empty, None, None, None, None, false)
val tableSpec =
UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, false)
test("CreateTableAsSelect: fail missing top-level column") {
val plan = CreateTableAsSelect(
UnresolvedIdentifier(Array("table_name")),
Expand Down
Loading

0 comments on commit 7e94f2a

Please sign in to comment.