
Commit 4a5ee81

Inconsistent result for DROPMALFORMED with pruned scan and unnecessarily casting all values when no fields are required

https://github.com/databricks/spark-csv/issues/218 https://github.com/databricks/spark-csv/issues/219

In this PR, I made the pruned scan parse the values in every column when DROPMALFORMED is enabled, so that malformed rows can still be detected, and then return only the required fields.

In addition, I changed the condition for falling back to a full table scan. If the required columns are empty, the scan now just produces empty rows instead of casting every value.
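For intuition, the inconsistency from databricks#218 can be reproduced along the following lines. This is a sketch, not a test from this PR: the file path and its contents are invented, and it assumes a Spark 1.x `sqlContext` with this package on the classpath.

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Hypothetical people.csv (contents invented for illustration):
//   name,age
//   alice,29
//   bob,not-a-number
val schema = StructType(Seq(
  StructField("name", StringType, true),
  StructField("age", IntegerType, true)))

val df = sqlContext.read
  .format("com.databricks.spark.csv")
  .schema(schema)
  .option("header", "true")
  .option("mode", "DROPMALFORMED")
  .load("people.csv")

// Bob's age cannot be cast to an integer, so his row is malformed.
df.select("name", "age").collect().length // 1
// Before this patch the pruned scan parsed only "name", never saw the bad
// "age", and returned 2 rows here; after it, both queries agree on 1.
df.select("name").collect().length        // 1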

Author: hyukjinkwon <[email protected]>

Closes databricks#220 from HyukjinKwon/ISSUE-218-non-all-cast.
HyukjinKwon authored and falaki committed Dec 28, 2015
1 parent fb1976d commit 4a5ee81
Showing 2 changed files with 40 additions and 12 deletions.
src/main/scala/com/databricks/spark/csv/CsvRelation.scala: 29 changes (18 additions, 11 deletions)
@@ -147,18 +147,25 @@ case class CsvRelation protected[spark] (
   override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
     val schemaFields = schema.fields
     val requiredFields = StructType(requiredColumns.map(schema(_))).fields
-    val isTableScan = requiredColumns.isEmpty || schemaFields.deep == requiredFields.deep
-    if (isTableScan) {
+    val shouldTableScan = schemaFields.deep == requiredFields.deep
+    val safeRequiredFields = if (dropMalformed) {
+      // If `dropMalformed` is enabled, then it needs to parse all the values
+      // so that we can decide which row is malformed.
+      requiredFields ++ schemaFields.filterNot(requiredFields.contains(_))
+    } else {
+      requiredFields
+    }
+    if (shouldTableScan) {
       buildScan
     } else {
-      val requiredIndices = new Array[Int](requiredFields.length)
+      val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
       schemaFields.zipWithIndex.filter {
-        case (field, _) => requiredFields.contains(field)
+        case (field, _) => safeRequiredFields.contains(field)
       }.foreach {
-        case(field, index) => requiredIndices(requiredFields.indexOf(field)) = index
+        case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
       }
 
-      val rowArray = new Array[Any](requiredIndices.length)
+      val rowArray = new Array[Any](safeRequiredIndices.length)
+      val requiredSize = requiredFields.length
       tokenRdd(schemaFields.map(_.name)).flatMap { tokens =>
         if (dropMalformed && schemaFields.length != tokens.size) {
           logger.warn(s"Dropping malformed line: ${tokens.mkString(delimiter.toString)}")
@@ -168,15 +175,15 @@ case class CsvRelation protected[spark] (
             s"${tokens.mkString(delimiter.toString)}")
         } else {
           val indexSafeTokens = if (permissive && schemaFields.length != tokens.size) {
-            tokens ++ new Array[String](schemaFields.length - tokens.length)
+            tokens ++ new Array[String](schemaFields.length - tokens.size)
           } else {
             tokens
           }
           try {
             var index: Int = 0
             var subIndex: Int = 0
-            while (subIndex < requiredIndices.length) {
-              index = requiredIndices(subIndex)
+            while (subIndex < safeRequiredIndices.length) {
+              index = safeRequiredIndices(subIndex)
               val field = schemaFields(index)
               rowArray(subIndex) = TypeCast.castTo(
                 indexSafeTokens(index),
@@ -185,7 +192,7 @@
                 treatEmptyValuesAsNulls)
               subIndex = subIndex + 1
             }
-            Some(Row.fromSeq(rowArray))
+            Some(Row.fromSeq(rowArray.take(requiredSize)))
           } catch {
             case nfe: java.lang.NumberFormatException if dropMalformed =>
               logger.warn("Number format exception. " +
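Distilled away from RDDs and `TypeCast`, the control flow of the patched `buildScan` looks roughly like the standalone sketch below; plain collections stand in for the RDD and `toInt` for `TypeCast.castTo`, so this is an illustration of the technique, not the relation code itself.

object PrunedScanSketch {
  // `schema`: all column names; `required`: the pruned columns a query asked for.
  def prunedScan(
      schema: Seq[String],
      required: Seq[String],
      lines: Seq[Array[String]],
      dropMalformed: Boolean): Seq[Seq[Int]] = {
    // Under DROPMALFORMED every column must be parsed, so a malformed value
    // in an unselected column still disqualifies the whole row (issue 218).
    val safeRequired =
      if (dropMalformed) required ++ schema.filterNot(required.contains(_))
      else required
    val safeIndices = safeRequired.map(schema.indexOf(_))
    lines.flatMap { tokens =>
      try {
        val parsed = safeIndices.map(i => tokens(i).trim.toInt)
        // Only the columns the caller asked for are returned. When `required`
        // is empty (e.g. a count) and DROPMALFORMED is off, nothing is parsed
        // at all (issue 219).
        Some(parsed.take(required.length))
      } catch {
        // Drop the row. (The real code treats the PERMISSIVE and FAILFAST
        // modes differently; here a failure without dropMalformed propagates.)
        case _: NumberFormatException if dropMalformed => None
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val schema = Seq("a", "b", "c")
    val lines = Seq(Array("1", "2", "3"), Array("4", "oops", "6"))
    // Selecting only "a": the second line is still dropped, because its
    // malformed "b" is parsed too. Prints: List(List(1))
    println(prunedScan(schema, Seq("a"), lines, dropMalformed = true))
  }
}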
src/test/scala/com/databricks/spark/csv/CsvSuite.scala: 23 changes (22 additions, 1 deletion)
@@ -169,6 +169,27 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     assert(results.size === numCars - 1)
   }
 
+  test("DSL test for DROPMALFORMED parsing mode with pruned scan") {
+    val strictSchema = new StructType(
+      Array(
+        StructField("Name", StringType, true),
+        StructField("Age", IntegerType, true),
+        StructField("Height", DoubleType, true)
+      )
+    )
+
+    val results = new CsvParser()
+      .withSchema(strictSchema)
+      .withUseHeader(true)
+      .withParserLib(parserLib)
+      .withParseMode(ParseModes.DROP_MALFORMED_MODE)
+      .csvFile(sqlContext, ageFile)
+      .select("Name")
+      .collect().size
+
+    assert(results === 1)
+  }
+
   test("DSL test for FAILFAST parsing mode") {
     val parser = new CsvParser()
       .withParseMode(ParseModes.FAIL_FAST_MODE)
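The new test above pins down the fix: even though only `Name` is selected, DROPMALFORMED now forces `Age` and `Height` to be parsed against the strict schema, so every line whose `Age` or `Height` fails to cast is dropped and exactly one row survives. Before this patch, the pruned scan parsed only `Name`, and those malformed lines slipped through.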
@@ -289,7 +310,6 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     assert(sqlContext.sql("SELECT year FROM carsTable").collect().size === numCars)
   }
 
-
   test("DSL test with empty file and known schema") {
     val results = new CsvParser()
       .withSchema(StructType(List(StructField("column", StringType, false))))
@@ -320,6 +340,7 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
 
     assert(results === 3)
   }
+
   test("DSL test with poorly formatted file and known schema") {
     val strictSchema = new StructType(
       Array(
