[SPARK-28667][SQL] Support InsertInto through the V2SessionCatalog
### What changes were proposed in this pull request?

This PR adds support for INSERT INTO via the V2SessionCatalog, through both the SQL and DataFrameWriter APIs.

### Why are the changes needed?

This allows V2 tables to be plugged in through the V2SessionCatalog and used seamlessly with the existing APIs.
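
As a quick illustration of what this enables (the table name, schema, and data below are made up for this sketch and do not come from the patch), both write paths should now resolve the target table through the V2SessionCatalog when a V2 provider backs it:

```scala
// Sketch only: assumes an active SparkSession `spark` and an existing
// table named "demo_table" backed by a V2 provider.
import spark.implicits._

val df = Seq((1L, "a"), (2L, "b")).toDF("id", "data")

// DataFrameWriter path
df.write.insertInto("demo_table")

// SQL path
spark.sql("INSERT INTO demo_table VALUES (3, 'c')")
```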

### Does this PR introduce any user-facing change?

No behavior changes.

### How was this patch tested?

Pulled a number of existing tests out into shared suites so that they can be run against both the DataFrameWriter and SQL code paths.
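
A minimal sketch of that sharing pattern (suite and table names here are illustrative; the shared suite in this patch is `InsertIntoTests`, and concrete suites plug in a `doInsert` hook, as the diff below shows):

```scala
import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode}
import org.apache.spark.sql.test.SharedSparkSession

// Illustrative only: each concrete suite decides whether the insert runs
// through SQL or through DataFrameWriter.insertInto, so the same test
// bodies exercise both code paths.
abstract class SharedInsertTests extends QueryTest with SharedSparkSession {
  import testImplicits._

  protected def doInsert(tableName: String, insert: DataFrame, mode: SaveMode): Unit

  test("insert appends rows to an existing table") {
    sql("CREATE TABLE t (id BIGINT, data STRING) USING parquet")
    val df = Seq((1L, "a"), (2L, "b")).toDF("id", "data")
    doInsert("t", df, SaveMode.Append)
    checkAnswer(spark.table("t"), df)
  }
}
```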

Closes apache#25507 from brkyvz/insertSesh.

Lead-authored-by: Burak Yavuz <[email protected]>
Co-authored-by: Burak Yavuz <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
2 people authored and cloud-fan committed Aug 27, 2019
1 parent 13b1eb6 commit e31aec9
Showing 9 changed files with 681 additions and 459 deletions.
@@ -24,9 +24,8 @@ import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalog.v2.{CatalogManager, CatalogNotFoundException, CatalogPlugin, LookupCatalog, TableChange}
import org.apache.spark.sql.catalog.v2._
import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform}
import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util.loadTable
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
@@ -36,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, DescribeTableStatement, InsertIntoStatement}
import org.apache.spark.sql.catalyst.plans.logical.sql._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.toPrettySQL
@@ -641,21 +640,13 @@
* [[ResolveRelations]] still resolves v1 tables.
*/
object ResolveTables extends Rule[LogicalPlan] {
import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util._

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident))
if catalog.isTemporaryTable(ident) =>
u // temporary views take precedence over catalog table names

case u @ UnresolvedRelation(CatalogObjectIdentifier(maybeCatalog, ident)) =>
maybeCatalog.orElse(sessionCatalog)
.flatMap(loadTable(_, ident))
.map {
case unresolved: UnresolvedTable => u
case resolved => DataSourceV2Relation.create(resolved)
}
.getOrElse(u)
case u: UnresolvedRelation =>
val v2TableOpt = lookupV2Relation(u.multipartIdentifier) match {
case scala.Left((_, _, tableOpt)) => tableOpt
case scala.Right(tableOpt) => tableOpt
}
v2TableOpt.map(DataSourceV2Relation.create).getOrElse(u)
}
}

@@ -770,40 +761,41 @@

object ResolveInsertInto extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(
UnresolvedRelation(CatalogObjectIdentifier(Some(tableCatalog), ident)), _, _, _, _)
if i.query.resolved =>
loadTable(tableCatalog, ident)
.map(DataSourceV2Relation.create)
.map(relation => {
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw new AnalysisException(
s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}")
}
case i @ InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) if i.query.resolved =>
lookupV2Relation(u.multipartIdentifier) match {
case scala.Left((_, _, Some(v2Table: Table))) =>
resolveV2Insert(i, v2Table)
case scala.Right(Some(v2Table: Table)) =>
resolveV2Insert(i, v2Table)
case _ =>
InsertIntoTable(i.table, i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists)
}
}

val partCols = partitionColumnNames(relation.table)
validatePartitionSpec(partCols, i.partitionSpec)
private def resolveV2Insert(i: InsertIntoStatement, table: Table): LogicalPlan = {
val relation = DataSourceV2Relation.create(table)
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw new AnalysisException(
s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}")
}

val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get)
val query = addStaticPartitionColumns(relation, i.query, staticPartitions)
val dynamicPartitionOverwrite = partCols.size > staticPartitions.size &&
conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
val partCols = partitionColumnNames(relation.table)
validatePartitionSpec(partCols, i.partitionSpec)

if (!i.overwrite) {
AppendData.byPosition(relation, query)
} else if (dynamicPartitionOverwrite) {
OverwritePartitionsDynamic.byPosition(relation, query)
} else {
OverwriteByExpression.byPosition(
relation, query, staticDeleteExpression(relation, staticPartitions))
}
})
.getOrElse(i)
val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get)
val query = addStaticPartitionColumns(relation, i.query, staticPartitions)
val dynamicPartitionOverwrite = partCols.size > staticPartitions.size &&
conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC

case i @ InsertIntoStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _)
if i.query.resolved =>
InsertIntoTable(i.table, i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists)
if (!i.overwrite) {
AppendData.byPosition(relation, query)
} else if (dynamicPartitionOverwrite) {
OverwritePartitionsDynamic.byPosition(relation, query)
} else {
OverwriteByExpression.byPosition(
relation, query, staticDeleteExpression(relation, staticPartitions))
}
}

private def partitionColumnNames(table: Table): Seq[String] = {
@@ -2773,6 +2765,39 @@ class Analyzer(
}
}
}

/**
* Performs the lookup of DataSourceV2 tables. The order of resolution is:
* 1. Check if this relation is a temporary table.
* 2. Check if it has a catalog identifier. Here we try to load the table; if we find it,
* we return it. The result returned by an explicit catalog is returned on the Left
* projection of the Either.
* 3. Try resolving the relation using the V2SessionCatalog if that is defined. If the
* V2SessionCatalog returns a V1 table definition (UnresolvedTable), then we return `None`
* on the Right projection so that we can fall back to the V1 code paths.
* The basic idea is: if a value is returned on the Left, a v2 catalog is defined and must be
* used to resolve the table. If a value is returned on the Right, then we can try creating a
* V2 relation if a V2 Table is defined; if it isn't, we should defer to the V1 code paths.
*/
private def lookupV2Relation(
identifier: Seq[String]
): Either[(CatalogPlugin, Identifier, Option[Table]), Option[Table]] = {
import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util._

identifier match {
case AsTemporaryViewIdentifier(ti) if catalog.isTemporaryTable(ti) =>
scala.Right(None)
case CatalogObjectIdentifier(Some(v2Catalog), ident) =>
scala.Left((v2Catalog, ident, loadTable(v2Catalog, ident)))
case CatalogObjectIdentifier(None, ident) =>
catalogManager.v2SessionCatalog.flatMap(loadTable(_, ident)) match {
case Some(_: UnresolvedTable) => scala.Right(None)
case other => scala.Right(other)
}
case _ => scala.Right(None)
}
}
}

/**
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

/**
* Throws user facing errors when passed invalid queries that fail to analyze.
20 changes: 17 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -367,10 +367,19 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
)
}

df.sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
val session = df.sparkSession
val provider = DataSource.lookupDataSource(source, session.sessionState.conf)
val canUseV2 = canUseV2Source(session, provider)
val sessionCatalogOpt = session.sessionState.analyzer.sessionCatalog

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(Some(catalog), ident) =>
insertInto(catalog, ident)
// TODO(SPARK-28667): Support the V2SessionCatalog

case CatalogObjectIdentifier(None, ident)
if canUseV2 && sessionCatalogOpt.isDefined && ident.namespace().length <= 1 =>
insertInto(sessionCatalogOpt.get, ident)

case AsTableIdentifier(tableIdentifier) =>
insertInto(tableIdentifier)
case other =>
@@ -382,7 +391,12 @@
private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = {
import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._

val table = DataSourceV2Relation.create(catalog.asTableCatalog.loadTable(ident))
val table = catalog.asTableCatalog.loadTable(ident) match {
case _: UnresolvedTable =>
return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption))
case t =>
DataSourceV2Relation.create(t)
}

val command = modeForDSV2 match {
case SaveMode.Append =>
@@ -18,26 +18,43 @@
package org.apache.spark.sql.sources.v2

import java.util
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode}
import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier}
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, SaveMode}
import org.apache.spark.sql.catalog.v2.CatalogPlugin
import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG}
import org.apache.spark.sql.sources.v2.utils.TestV2SessionCatalogBase
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class DataSourceV2DataFrameSessionCatalogSuite
extends SessionCatalogTest[InMemoryTable, InMemoryTableSessionCatalog] {
extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = false)
with SessionCatalogTest[InMemoryTable, InMemoryTableSessionCatalog] {

import testImplicits._

override protected def doInsert(tableName: String, insert: DataFrame, mode: SaveMode): Unit = {
val dfw = insert.write.format(v2Format)
if (mode != null) {
dfw.mode(mode)
}
dfw.insertInto(tableName)
}

override protected def verifyTable(tableName: String, expected: DataFrame): Unit = {
checkAnswer(spark.table(tableName), expected)
checkAnswer(sql(s"SELECT * FROM $tableName"), expected)
checkAnswer(sql(s"SELECT * FROM default.$tableName"), expected)
checkAnswer(sql(s"TABLE $tableName"), expected)
}

override protected val catalogAndNamespace: String = ""

test("saveAsTable: Append mode should not fail if the table already exists " +
"and a same-name temp view exist") {
@@ -97,21 +114,16 @@ private[v2] trait SessionCatalogTest[T <: Table, Catalog <: TestV2SessionCatalog
protected val catalogClassName: String = classOf[InMemoryTableSessionCatalog].getName

before {
spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, catalogClassName)
spark.conf.set(V2_SESSION_CATALOG.key, catalogClassName)
}

override def afterEach(): Unit = {
super.afterEach()
catalog("session").asInstanceOf[Catalog].clearTables()
spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[V2SessionCatalog].getName)
spark.conf.set(V2_SESSION_CATALOG.key, classOf[V2SessionCatalog].getName)
}

protected def verifyTable(tableName: String, expected: DataFrame): Unit = {
checkAnswer(spark.table(tableName), expected)
checkAnswer(sql(s"SELECT * FROM $tableName"), expected)
checkAnswer(sql(s"SELECT * FROM default.$tableName"), expected)
checkAnswer(sql(s"TABLE $tableName"), expected)
}
protected def verifyTable(tableName: String, expected: DataFrame): Unit

import testImplicits._

