From 4a5ee81d94d02e1bf8b2e1f437aeb9134ff14719 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 28 Dec 2015 15:02:33 -0800
Subject: [PATCH] Inconsistent result for DROPMALFORMED with pruned scan and
 unnecessarily casting all values when no fields are required

https://github.com/databricks/spark-csv/issues/218
https://github.com/databricks/spark-csv/issues/219

In this PR, I made the pruned scan parse all the column values when
DROPMALFORMED is enabled, so that malformed rows can still be detected,
while returning only the required fields.

In addition, I changed the condition for the table scan: if the required
columns are empty, it now just produces empty rows instead of casting
every value.

Author: hyukjinkwon

Closes #220 from HyukjinKwon/ISSUE-218-non-all-cast.
---
 .../databricks/spark/csv/CsvRelation.scala | 29 ++++++++++++-------
 .../com/databricks/spark/csv/CsvSuite.scala | 23 ++++++++++++++-
 2 files changed, 40 insertions(+), 12 deletions(-)

diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index 196734d..dcab9c8 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -147,18 +147,25 @@ case class CsvRelation protected[spark] (
   override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
     val schemaFields = schema.fields
     val requiredFields = StructType(requiredColumns.map(schema(_))).fields
-    val isTableScan = requiredColumns.isEmpty || schemaFields.deep == requiredFields.deep
-    if (isTableScan) {
+    val shouldTableScan = schemaFields.deep == requiredFields.deep
+    val safeRequiredFields = if (dropMalformed) {
+      // If `dropMalformed` is enabled, then it needs to parse all the values
+      // so that we can decide which row is malformed.
+      requiredFields ++ schemaFields.filterNot(requiredFields.contains(_))
+    } else {
+      requiredFields
+    }
+    if (shouldTableScan) {
       buildScan
     } else {
-      val requiredIndices = new Array[Int](requiredFields.length)
+      val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
       schemaFields.zipWithIndex.filter {
-        case (field, _) => requiredFields.contains(field)
+        case (field, _) => safeRequiredFields.contains(field)
       }.foreach {
-        case(field, index) => requiredIndices(requiredFields.indexOf(field)) = index
+        case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
       }
-
-      val rowArray = new Array[Any](requiredIndices.length)
+      val rowArray = new Array[Any](safeRequiredIndices.length)
+      val requiredSize = requiredFields.length
       tokenRdd(schemaFields.map(_.name)).flatMap { tokens =>
         if (dropMalformed && schemaFields.length != tokens.size) {
           logger.warn(s"Dropping malformed line: ${tokens.mkString(delimiter.toString)}")
@@ -168,15 +175,15 @@ case class CsvRelation protected[spark] (
             s"${tokens.mkString(delimiter.toString)}")
         } else {
           val indexSafeTokens = if (permissive && schemaFields.length != tokens.size) {
-            tokens ++ new Array[String](schemaFields.length - tokens.length)
+            tokens ++ new Array[String](schemaFields.length - tokens.size)
           } else {
             tokens
           }
           try {
             var index: Int = 0
             var subIndex: Int = 0
-            while (subIndex < requiredIndices.length) {
-              index = requiredIndices(subIndex)
+            while (subIndex < safeRequiredIndices.length) {
+              index = safeRequiredIndices(subIndex)
               val field = schemaFields(index)
               rowArray(subIndex) = TypeCast.castTo(
                 indexSafeTokens(index),
@@ -185,7 +192,7 @@ case class CsvRelation protected[spark] (
                 treatEmptyValuesAsNulls)
               subIndex = subIndex + 1
             }
-            Some(Row.fromSeq(rowArray))
+            Some(Row.fromSeq(rowArray.take(requiredSize)))
           } catch {
             case nfe: java.lang.NumberFormatException if dropMalformed =>
               logger.warn("Number format exception. " +
diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
index 61ccfdc..9acc7f3 100755
--- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -169,6 +169,27 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     assert(results.size === numCars - 1)
   }
 
+  test("DSL test for DROPMALFORMED parsing mode with pruned scan") {
+    val strictSchema = new StructType(
+      Array(
+        StructField("Name", StringType, true),
+        StructField("Age", IntegerType, true),
+        StructField("Height", DoubleType, true)
+      )
+    )
+
+    val results = new CsvParser()
+      .withSchema(strictSchema)
+      .withUseHeader(true)
+      .withParserLib(parserLib)
+      .withParseMode(ParseModes.DROP_MALFORMED_MODE)
+      .csvFile(sqlContext, ageFile)
+      .select("Name")
+      .collect().size
+
+    assert(results === 1)
+  }
+
   test("DSL test for FAILFAST parsing mode") {
     val parser = new CsvParser()
       .withParseMode(ParseModes.FAIL_FAST_MODE)
@@ -289,7 +310,6 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     assert(sqlContext.sql("SELECT year FROM carsTable").collect().size === numCars)
   }
 
-
   test("DSL test with empty file and known schema") {
     val results = new CsvParser()
       .withSchema(StructType(List(StructField("column", StringType, false))))
@@ -320,6 +340,7 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     assert(results === 3)
   }
 
+
   test("DSL test with poorly formatted file and known schema") {
     val strictSchema = new StructType(
       Array(
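--
Reviewer note, not part of the patch: a minimal sketch of the inconsistency
this fixes. It assumes a spark-shell session with `sqlContext` in scope and a
hypothetical file "ages.csv" (the test's `ageFile` fixture plays this role)
whose second data row has a non-numeric Age:

    Name,Age,Height
    alice,10,1.80
    bob,notanumber,1.75

    import org.apache.spark.sql.types._
    import com.databricks.spark.csv.CsvParser
    import com.databricks.spark.csv.util.ParseModes

    val strictSchema = new StructType(Array(
      StructField("Name", StringType, true),
      StructField("Age", IntegerType, true),
      StructField("Height", DoubleType, true)))

    val df = new CsvParser()
      .withSchema(strictSchema)
      .withUseHeader(true)
      .withParseMode(ParseModes.DROP_MALFORMED_MODE)
      .csvFile(sqlContext, "ages.csv")

    // A full scan casts every column, so the bad Age is caught and the
    // malformed row for bob is dropped:
    df.collect().size                 // 1

    // Before this patch, the pruned scan cast only `Name`, never saw the
    // bad Age, and kept bob, returning 2 rows. With this patch the pruned
    // scan still parses all columns (to decide which rows are malformed)
    // but returns only the required fields, so both scans agree:
    df.select("Name").collect().size  // 1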