[SPARK-26744][SQL] Support schema validation in FileDataSourceV2 framework

## What changes were proposed in this pull request?

The file source has a schema validation feature, which validates two schemas:
1. the user-specified schema when reading.
2. the schema of the input data when writing.

If a file source doesn't support the schema, we can fail the query earlier.

This PR implements the same feature in the `FileDataSourceV2` framework. Compared to `FileFormat`, `FileDataSourceV2` has multiple layers, so the API is added in two places:
1. Read path: the table schema is determined in `TableProvider.getTable`. The actual read schema can be a subset of the table schema. This PR validates the actual read schema in `FileScan`.
2. Write path: the actual output schema is validated in `FileWriteBuilder` (see the sketch after this list).
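For illustration only (the `MyTextScan` class, its name, and its type whitelist are hypothetical, not part of this commit), a source built on this framework would opt into the read-side validation by overriding the two new members of `FileScan`; its `FileWriteBuilder` subclass overrides the same two members for the write side:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
import org.apache.spark.sql.types._

// Hypothetical flat-text source: only a few primitive column types are readable.
case class MyTextScan(
    sparkSession: SparkSession,
    fileIndex: PartitioningAwareFileIndex,
    readSchema: StructType) extends FileScan(sparkSession, fileIndex, readSchema) {

  // Hook added by this PR: toBatch throws an AnalysisException for any rejected field type.
  override def supportsDataType(dataType: DataType): Boolean = dataType match {
    case StringType | IntegerType | LongType | DoubleType => true
    case _ => false
  }

  // Used to build the "<format> data source does not support <type> data type." message.
  override def formatName: String = "MyText"

  // Reader construction is unrelated to schema validation and omitted from this sketch.
  override def createReaderFactory(): PartitionReaderFactory =
    throw new UnsupportedOperationException("illustration only")
}
```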

## How was this patch tested?

Unit test

Closes apache#23714 from gengliangwang/schemaValidationV2.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
gengliangwang authored and cloud-fan committed Feb 16, 2019
1 parent 4cabab8 commit 4dce45a
Showing 6 changed files with 167 additions and 77 deletions.
@@ -18,22 +18,39 @@ package org.apache.spark.sql.execution.datasources.v2

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}

abstract class FileScan(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex) extends Scan with Batch {
fileIndex: PartitioningAwareFileIndex,
readSchema: StructType) extends Scan with Batch {
/**
* Returns whether a file with `path` could be split or not.
*/
def isSplitable(path: Path): Boolean = {
false
}

/**
* Returns whether this format supports the given [[DataType]] in read path.
* By default all data types are supported.
*/
def supportsDataType(dataType: DataType): Boolean = true

/**
* The string that represents the format that this data source provider uses. This is
* overridden by children to provide a nice alias for the data source. For example:
*
* {{{
* override def formatName: String = "ORC"
* }}}
*/
def formatName: String

protected def partitions: Seq[FilePartition] = {
val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
@@ -57,5 +74,13 @@ abstract class FileScan(
partitions.toArray
}

override def toBatch: Batch = this
override def toBatch: Batch = {
readSchema.foreach { field =>
if (!supportsDataType(field.dataType)) {
throw new AnalysisException(
s"$formatName data source does not support ${field.dataType.catalogString} data type.")
}
}
this
}
}
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.SerializableConfiguration

abstract class FileWriteBuilder(options: DataSourceOptions)
@@ -104,12 +104,34 @@ abstract class FileWriteBuilder(options: DataSourceOptions)
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory

/**
* Returns whether this format supports the given [[DataType]] in write path.
* By default all data types are supported.
*/
def supportsDataType(dataType: DataType): Boolean = true

/**
* The string that represents the format that this data source provider uses. This is
* overridden by children to provide a nice alias for the data source. For example:
*
* {{{
* override def formatName: String = "ORC"
* }}}
*/
def formatName: String

private def validateInputs(): Unit = {
assert(schema != null, "Missing input data schema")
assert(queryId != null, "Missing query ID")
assert(mode != null, "Missing save mode")
assert(options.paths().length == 1)
DataSource.validateSchema(schema)
schema.foreach { field =>
if (!supportsDataType(field.dataType)) {
throw new AnalysisException(
s"$formatName data source does not support ${field.dataType.catalogString} data type.")
}
}
}

private def getJobInstance(hadoopConf: Configuration, path: Path): Job = {
@@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.sources.v2.{DataSourceOptions, Table}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types._

class OrcDataSourceV2 extends FileDataSourceV2 {

@@ -42,3 +42,20 @@ class OrcDataSourceV2 extends FileDataSourceV2 {
OrcTable(tableName, sparkSession, options, Some(schema))
}
}

object OrcDataSourceV2 {
def supportsDataType(dataType: DataType): Boolean = dataType match {
case _: AtomicType => true

case st: StructType => st.forall { f => supportsDataType(f.dataType) }

case ArrayType(elementType, _) => supportsDataType(elementType)

case MapType(keyType, valueType, _) =>
supportsDataType(keyType) && supportsDataType(valueType)

case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)

case _ => false
}
}
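A few spot checks of this whitelist (spark-shell style; assumes `org.apache.spark.sql.types._` is imported and the `OrcDataSourceV2` companion shown above is in scope):

```scala
import org.apache.spark.sql.types._

OrcDataSourceV2.supportsDataType(IntegerType)                      // true: atomic type
OrcDataSourceV2.supportsDataType(ArrayType(StringType))            // true: supported element type
OrcDataSourceV2.supportsDataType(                                  // true: recurses into nested fields
  new StructType().add("m", MapType(StringType, LongType)))
OrcDataSourceV2.supportsDataType(CalendarIntervalType)             // false: not atomic, not whitelisted
OrcDataSourceV2.supportsDataType(ArrayType(CalendarIntervalType))  // false: unsupported element type
```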
@@ -23,15 +23,15 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.SerializableConfiguration

case class OrcScan(
sparkSession: SparkSession,
hadoopConf: Configuration,
fileIndex: PartitioningAwareFileIndex,
dataSchema: StructType,
readSchema: StructType) extends FileScan(sparkSession, fileIndex) {
readSchema: StructType) extends FileScan(sparkSession, fileIndex, readSchema) {
override def isSplitable(path: Path): Boolean = true

override def createReaderFactory(): PartitionReaderFactory = {
@@ -40,4 +40,10 @@ case class OrcScan(
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, fileIndex.partitionSchema, readSchema)
}

override def supportsDataType(dataType: DataType): Boolean = {
OrcDataSourceV2.supportsDataType(dataType)
}

override def formatName: String = "ORC"
}
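For context, a spark-shell style sketch of the read-side behavior this enables (it mirrors the test suite at the end of this diff; the input path is made up):

```scala
// With "orc" removed from SQLConf.USE_V1_SOURCE_READER_LIST (as the suite below does),
// a user-specified schema containing an unsupported type is rejected in FileScan.toBatch,
// before any ORC files are opened.
import org.apache.spark.sql.types.{NullType, StructField, StructType}

val schema = StructType(StructField("a", NullType, nullable = true) :: Nil)
spark.read.schema(schema).format("orc").load("/tmp/some_orc_dir").collect()
// => org.apache.spark.sql.AnalysisException:
//    ORC data source does not support null data type.
```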
@@ -63,4 +63,10 @@ class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(optio
}
}
}

override def supportsDataType(dataType: DataType): Boolean = {
OrcDataSourceV2.supportsDataType(dataType)
}

override def formatName: String = "ORC"
}
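And the corresponding write-side behavior (again spark-shell style, mirroring the suite below; the output path is made up):

```scala
// With "orc" removed from SQLConf.USE_V1_SOURCE_WRITER_LIST, the V2 write path rejects an
// unsupported column type in FileWriteBuilder.validateInputs, before any write job is launched.
spark.sql("select interval 1 days as i")
  .write.format("orc").mode("overwrite").save("/tmp/orc_interval_out")
// => org.apache.spark.sql.AnalysisException:
//    ORC data source does not support calendarinterval data type.
```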
@@ -329,83 +329,97 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") {
withTempDir { dir =>
val tempDir = new File(dir, "files").getCanonicalPath
// TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
// write path
Seq("csv", "json", "parquet", "orc").foreach { format =>
var msg = intercept[AnalysisException] {
sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.contains("Cannot save interval data type into external storage."))

msg = intercept[AnalysisException] {
spark.udf.register("testType", () => new IntervalData())
sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support calendarinterval data type."))
Seq(true, false).foreach { useV1 =>
val useV1List = if (useV1) {
"orc"
} else {
""
}
def errorMessage(format: String, isWrite: Boolean): String = {
if (isWrite && (useV1 || format != "orc")) {
"cannot save interval data type into external storage."
} else {
s"$format data source does not support calendarinterval data type."
}
}

withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) {
// write path
Seq("csv", "json", "parquet", "orc").foreach { format =>
var msg = intercept[AnalysisException] {
sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, true)))
}

// read path
Seq("parquet", "csv").foreach { format =>
var msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support calendarinterval data type."))

msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support calendarinterval data type."))
// read path
Seq("parquet", "csv").foreach { format =>
var msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))

msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))
}
}
}
}
}

test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") {
// TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc",
SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
withTempDir { dir =>
val tempDir = new File(dir, "files").getCanonicalPath

Seq("parquet", "csv", "orc").foreach { format =>
// write path
var msg = intercept[AnalysisException] {
sql("select null").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support null data type."))

msg = intercept[AnalysisException] {
spark.udf.register("testType", () => new NullData())
sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support null data type."))

// read path
msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", NullType, true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support null data type."))

msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support null data type."))
Seq(true, false).foreach { useV1 =>
val useV1List = if (useV1) {
"orc"
} else {
""
}
def errorMessage(format: String): String = {
s"$format data source does not support null data type."
}
withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1List,
SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) {
withTempDir { dir =>
val tempDir = new File(dir, "files").getCanonicalPath

Seq("parquet", "csv", "orc").foreach { format =>
// write path
var msg = intercept[AnalysisException] {
sql("select null").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(errorMessage(format)))

msg = intercept[AnalysisException] {
spark.udf.register("testType", () => new NullData())
sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(errorMessage(format)))

// read path
msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", NullType, true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(errorMessage(format)))

msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(errorMessage(format)))
}
}
}
}