Skip to content

Commit

Permalink
[SPARK-26495][SQL] Simplify the SelectedField extractor.
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
The current `SelectedField` extractor is somewhat complicated and it seems to be handling cases that should be handled automatically:

- `GetArrayItem(child: GetStructFieldObject())`
- `GetArrayStructFields(child: GetArrayStructFields())`
- `GetMapValue(child: GetStructFieldObject())`

This PR removes those cases and simplifies the extractor by passing down the data type instead of a field.

## How was this patch tested?
Existing tests.

Closes apache#23397 from hvanhovell/SPARK-26495.

Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
  • Loading branch information
hvanhovell committed Dec 31, 2018
1 parent c0b9db1 commit c036836
Showing 1 changed file with 41 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.execution

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -51,8 +52,6 @@ import org.apache.spark.sql.types._
* type appropriate to the complex type extractor. In our example, the name of the child expression
* is "name" and its data type is a [[org.apache.spark.sql.types.StructType]] with a single string
* field named "first".
*
* @param expr the top-level complex type extractor
*/
private[execution] object SelectedField {
def unapply(expr: Expression): Option[StructField] = {
Expand All @@ -64,71 +63,51 @@ private[execution] object SelectedField {
selectField(unaliased, None)
}

private def selectField(expr: Expression, fieldOpt: Option[StructField]): Option[StructField] = {
/**
* Convert an expression into the parts of the schema (the field) it accesses.
*/
private def selectField(expr: Expression, dataTypeOpt: Option[DataType]): Option[StructField] = {
expr match {
// No children. Returns a StructField with the attribute name or None if fieldOpt is None.
case AttributeReference(name, dataType, nullable, metadata) =>
fieldOpt.map(field =>
StructField(name, wrapStructType(dataType, field), nullable, metadata))
// Handles case "expr0.field[n]", where "expr0" is of struct type and "expr0.field" is of
// array type.
case GetArrayItem(x @ GetStructFieldObject(child, field @ StructField(name,
dataType, nullable, metadata)), _) =>
val childField = fieldOpt.map(field => StructField(name,
wrapStructType(dataType, field), nullable, metadata)).getOrElse(field)
selectField(child, Some(childField))
// Handles case "expr0.field[n]", where "expr0.field" is of array type.
case GetArrayItem(child, _) =>
selectField(child, fieldOpt)
// Handles case "expr0.field.subfield", where "expr0" and "expr0.field" are of array type.
case GetArrayStructFields(child: GetArrayStructFields,
field @ StructField(name, dataType, nullable, metadata), _, _, _) =>
val childField = fieldOpt.map(field => StructField(name,
wrapStructType(dataType, field),
nullable, metadata)).orElse(Some(field))
selectField(child, childField)
// Handles case "expr0.field", where "expr0" is of array type.
case GetArrayStructFields(child,
field @ StructField(name, dataType, nullable, metadata), _, _, _) =>
val childField =
fieldOpt.map(field => StructField(name,
wrapStructType(dataType, field),
nullable, metadata)).orElse(Some(field))
selectField(child, childField)
// Handles case "expr0.field[key]", where "expr0" is of struct type and "expr0.field" is of
// map type.
case GetMapValue(x @ GetStructFieldObject(child, field @ StructField(name,
dataType,
nullable, metadata)), _) =>
val childField = fieldOpt.map(field => StructField(name,
wrapStructType(dataType, field),
nullable, metadata)).orElse(Some(field))
selectField(child, childField)
// Handles case "expr0.field[key]", where "expr0.field" is of map type.
case a: Attribute =>
dataTypeOpt.map { dt =>
StructField(a.name, dt, a.nullable)
}
case c: GetStructField =>
val field = c.childSchema(c.ordinal)
val newField = field.copy(dataType = dataTypeOpt.getOrElse(field.dataType))
selectField(c.child, Option(struct(newField)))
case GetArrayStructFields(child, field, _, _, containsNull) =>
val newFieldDataType = dataTypeOpt match {
case None =>
// GetArrayStructFields is the top-level extractor. This means its result is
// not pruned and we need to use the element type of the array it is producing.
field.dataType
case Some(ArrayType(dataType, _)) =>
// GetArrayStructFields is part of a chain of extractors and its result is pruned
// by a parent expression. In this case we need to use the parent's element type.
dataType
case Some(x) =>
// This should not happen.
throw new AnalysisException(s"DataType '$x' is not supported by GetArrayStructFields.")
}
val newField = StructField(field.name, newFieldDataType, field.nullable)
selectField(child, Option(ArrayType(struct(newField), containsNull)))
case GetMapValue(child, _) =>
selectField(child, fieldOpt)
// Handles case "expr0.field", where expr0 is of struct type.
case GetStructFieldObject(child,
field @ StructField(name, dataType, nullable, metadata)) =>
val childField = fieldOpt.map(field => StructField(name,
wrapStructType(dataType, field),
nullable, metadata)).orElse(Some(field))
selectField(child, childField)
// GetMapValue does not select a field from a struct (i.e. prune the struct) so it can't be
// the top-level extractor. However it can be part of an extractor chain.
val MapType(keyType, _, valueContainsNull) = child.dataType
val opt = dataTypeOpt.map(dt => MapType(keyType, dt, valueContainsNull))
selectField(child, opt)
case GetArrayItem(child, _) =>
// GetArrayItem does not select a field from a struct (i.e. prune the struct) so it can't be
// the top-level extractor. However it can be part of an extractor chain.
val ArrayType(_, containsNull) = child.dataType
val opt = dataTypeOpt.map(dt => ArrayType(dt, containsNull))
selectField(child, opt)
case _ =>
None
}
}

// Constructs a composition of complex types with a StructType(Array(field)) at its core. Returns
// a StructType for a StructType, an ArrayType for an ArrayType and a MapType for a MapType.
// Array element types and map value types are rewrapped recursively; the map key type and the
// containsNull / valueContainsNull flags are carried over unchanged. Any other data type has no
// matching case below, so the match fails with a scala.MatchError.
private def wrapStructType(dataType: DataType, field: StructField): DataType = {
dataType match {
case _: StructType =>
// Base case: the struct is replaced by one that holds only `field`.
StructType(Array(field))
case ArrayType(elementType, containsNull) =>
// Keep the array shell and recurse into its element type.
ArrayType(wrapStructType(elementType, field), containsNull)
case MapType(keyType, valueType, valueContainsNull) =>
// Keep the map shell and key type; only the value type is rewrapped.
MapType(keyType, wrapStructType(valueType, field), valueContainsNull)
}
}
/** Builds a [[StructType]] whose only member is the given `field`. */
private def struct(field: StructField): StructType = StructType(Array(field))
}

0 comments on commit c036836

Please sign in to comment.