[SPARK-38918][SQL] Nested column pruning should filter out attributes that do not belong to the current relation

### What changes were proposed in this pull request?
This PR updates `ProjectionOverSchema` to use the output attributes of the data source relation to filter the attributes used during nested schema pruning. This is needed because the attributes in the pruned schema do not necessarily belong to the current data source relation. For example, if a filter contains a correlated subquery, the subquery's children can contain attributes from both the inner query and the outer query. Since the `RewriteSubquery` batch runs after the early scan pushdown rules, nested schema pruning can wrongly use the inner query's attributes to prune the outer query's data schema, causing wrong results and unexpected exceptions.
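As a concrete illustration (this mirrors the regression test added in this patch), a query of the following shape can hit the issue, because the rewritten filter mixes attributes of the outer relation `contacts` with attributes of the inner relation `employees`:

```scala
// Both tables have nested columns; see the new test in SchemaPruningSuite below.
val query = sql(
  """
    |select count(*)
    |from contacts c
    |where not exists (select null from employees e where e.name.first = c.name.first
    |  and e.employer.name = c.employer.company.name)
    |""".stripMargin)
```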

### Why are the changes needed?

To fix a bug in `SchemaPruning`.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test

Closes apache#36216 from allisonwang-db/spark-38918-nested-column-pruning.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
allisonwang-db authored and viirya committed Apr 27, 2022
1 parent 1b7c636 commit 150434b
Showing 6 changed files with 57 additions and 9 deletions.
```diff
@@ -24,15 +24,19 @@ import org.apache.spark.sql.types._
  * field indexes and field counts of complex type extractors and attributes
  * are adjusted to fit the schema. All other expressions are left as-is. This
  * class is motivated by columnar nested schema pruning.
+ *
+ * @param schema nested column schema
+ * @param output output attributes of the data source relation. They are used to filter out
+ *               attributes in the schema that do not belong to the current relation.
  */
-case class ProjectionOverSchema(schema: StructType) {
+case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
   private val fieldNames = schema.fieldNames.toSet
 
   def unapply(expr: Expression): Option[Expression] = getProjection(expr)
 
   private def getProjection(expr: Expression): Option[Expression] =
     expr match {
-      case a: AttributeReference if fieldNames.contains(a.name) =>
+      case a: AttributeReference if fieldNames.contains(a.name) && output.contains(a) =>
        Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
      case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
        getProjection(child).map {
```
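For context, `ProjectionOverSchema` is used as an extractor while transforming expression trees, as the rules changed below show. A minimal sketch of the call pattern after this change, where `prunedSchema`, `relation`, and `expr` are placeholders for whatever the calling rule has at hand:

```scala
// Sketch only: prunedSchema, relation, and expr stand in for the calling
// rule's pruned schema, data source relation, and expression to rewrite.
val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(relation.output))
val rewritten = expr.transformDown {
  // With the new output check, an AttributeReference that merely shares a
  // name with a pruned field but belongs to another relation (for example,
  // the outer query of a correlated subquery) no longer matches here.
  case projectionOverSchema(newExpr) => newExpr
}
```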
```diff
@@ -60,6 +60,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
   override protected val excludedOnceBatches: Set[String] =
     Set(
       "PartitionPruning",
+      "RewriteSubquery",
       "Extract Python UDFs")
 
   protected def fixedPoint =
```
```diff
@@ -229,7 +229,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
      }
 
      // Builds new projection.
-     val projectionOverSchema = ProjectionOverSchema(prunedSchema)
+     val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(s.output))
      val newProjects = p.projectList.map(_.transformDown {
        case projectionOverSchema(expr) => expr
      }).map { case expr: NamedExpression => expr }
```
```diff
@@ -91,8 +91,8 @@ object SchemaPruning extends Rule[LogicalPlan] {
        if (countLeaves(hadoopFsRelation.dataSchema) > countLeaves(prunedDataSchema) ||
          countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)) {
          val prunedRelation = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema)
-         val projectionOverSchema =
-           ProjectionOverSchema(prunedDataSchema.merge(prunedMetadataSchema))
+         val projectionOverSchema = ProjectionOverSchema(
+           prunedDataSchema.merge(prunedMetadataSchema), AttributeSet(relation.output))
          Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
            prunedRelation, projectionOverSchema))
        } else {
```
```diff
@@ -19,8 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
-import org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.planning.ScanOperation
@@ -320,7 +319,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
 
    val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
 
-   val projectionOverSchema = ProjectionOverSchema(output.toStructType)
+   val projectionOverSchema =
+     ProjectionOverSchema(output.toStructType, AttributeSet(output))
    val projectionFunc = (expr: Expression) => expr transformDown {
      case projectionOverSchema(newExpr) => newExpr
    }
```
```diff
@@ -61,11 +61,15 @@ abstract class SchemaPruningSuite
   override protected def sparkConf: SparkConf =
     super.sparkConf.set(SQLConf.ANSI_STRICT_INDEX_OPERATOR.key, "false")
 
+  case class Employee(id: Int, name: FullName, employer: Company)
+
   val janeDoe = FullName("Jane", "X.", "Doe")
   val johnDoe = FullName("John", "Y.", "Doe")
   val susanSmith = FullName("Susan", "Z.", "Smith")
 
-  val employer = Employer(0, Company("abc", "123 Business Street"))
+  val company = Company("abc", "123 Business Street")
+
+  val employer = Employer(0, company)
   val employerWithNullCompany = Employer(1, null)
   val employerWithNullCompany2 = Employer(2, null)
 
@@ -81,6 +85,8 @@ abstract class SchemaPruningSuite
     Department(1, "Marketing", 1, employerWithNullCompany) ::
     Department(2, "Operation", 4, employerWithNullCompany2) :: Nil
 
+  val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, company) :: Nil
+
   case class Name(first: String, last: String)
   case class BriefContact(id: Int, name: Name, address: String)
 
@@ -621,6 +627,26 @@ abstract class SchemaPruningSuite
     }
   }
 
+  testSchemaPruning("SPARK-38918: nested schema pruning with correlated subqueries") {
+    withContacts {
+      withEmployees {
+        val query = sql(
+          """
+            |select count(*)
+            |from contacts c
+            |where not exists (select null from employees e where e.name.first = c.name.first
+            |  and e.employer.name = c.employer.company.name)
+            |""".stripMargin)
+        checkScan(query,
+          "struct<name:struct<first:string,middle:string,last:string>," +
+            "employer:struct<id:int,company:struct<name:string,address:string>>>",
+          "struct<name:struct<first:string,middle:string,last:string>," +
+            "employer:struct<name:string,address:string>>")
+        checkAnswer(query, Row(3))
+      }
+    }
+  }
+
   protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
     test(s"Spark vectorized reader - without partition data column - $testName") {
       withSQLConf(vectorizedReaderEnabledKey -> "true") {
@@ -701,6 +727,23 @@ abstract class SchemaPruningSuite
     }
   }
 
+  private def withEmployees(testThunk: => Unit): Unit = {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+
+      makeDataSourceFile(employees, new File(path + "/employees"))
+
+      // Providing user specified schema. Inferred schema from different data sources might
+      // be different.
+      val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, `last`: STRING>, " +
+        "`employer` STRUCT<`name`: STRING, `address`: STRING>"
+      spark.read.format(dataSourceName).schema(schema).load(path + "/employees")
+        .createOrReplaceTempView("employees")
+
+      testThunk
+    }
+  }
+
   case class MixedCaseColumn(a: String, B: Int)
   case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn)
 
```
