Roundtrip null values of any type
This pull request adds functionality to spark-csv for writing null values of any type to a CSV file and reading them back in as null. Two changes were made to enable this.

First, the `com.databricks.spark.csv` package previously hardcoded the null string to "`null`" when saving to a CSV file. This was changed to read the null token from the passed-in parameters map, using the value for "`nullValue`", which makes it possible to write null values as empty strings. The default remains "`null`" to preserve the library's previous behavior.
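For example, the write side of the new round-trip test saves nulls as empty fields rather than the literal string "null" (a minimal sketch; the DataFrame and output path are illustrative):

import com.databricks.spark.csv._  // brings saveAsCsvFile into scope

// Null values of any type are written as empty fields instead of "null".
agesDf.saveAsCsvFile("/tmp/null-numbers.csv", Map("header" -> "true", "nullValue" -> ""))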

Secondly, the `castTo` method from `com.databricks.spark.csv.util.TypeCast` explicitly excluded `StringType` from its empty-string-to-null conversion, so it could never return null when the `castType` was an instance of `StringType`. As a result, it was not possible to read string values from file as null. This pull request adds a setting `treatEmptyValuesAsNulls` that allows empty string values in fields marked as nullable to be read as null values, as expected. The previous behavior remains the default, so this pull request only changes behavior when `treatEmptyValuesAsNulls` is explicitly set to true. The corresponding changes to `CsvParser` and `CsvRelation` were made to pass the new setting through.
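Reading the file back with the new option, sketched from the new test below (the schema and path names are illustrative):

import com.databricks.spark.csv.CsvParser

// Empty strings in nullable fields (including StringType) are read back as null.
val agesCopy = new CsvParser()
  .withSchema(agesSchema)
  .withUseHeader(true)
  .withTreatEmptyValuesAsNulls(true)
  .csvFile(sqlContext, "/tmp/null-numbers.csv")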

Additionally, a unit test has been added to `CsvSuite` that round-trips both string and non-string null values, writing nulls to file and reading them back in as nulls.

Author: Andres Perez <[email protected]>

Closes databricks#147 from andy327/feat-set-null-tokens.
andy327 authored and falaki committed Oct 4, 2015
1 parent ad11f75 commit ee152f3
Showing 7 changed files with 57 additions and 5 deletions.
1 change: 1 addition & 0 deletions build.sbt
@@ -95,6 +95,7 @@ mimaDefaultSettings ++ Seq(
ProblemFilters.excludePackage("com.databricks.spark.csv.CsvRelation"),
ProblemFilters.excludePackage("com.databricks.spark.csv.util.InferSchema"),
ProblemFilters.excludePackage("com.databricks.spark.sql.readers"),
ProblemFilters.excludePackage("com.databricks.spark.csv.util.TypeCast"),
// We allowed the private `CsvRelation` type to leak into the public method signature:
ProblemFilters.exclude[IncompatibleResultTypeProblem](
"com.databricks.spark.csv.DefaultSource.createRelation")
8 changes: 8 additions & 0 deletions src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -35,6 +35,7 @@ class CsvParser extends Serializable {
private var parseMode: String = ParseModes.DEFAULT
private var ignoreLeadingWhiteSpace: Boolean = false
private var ignoreTrailingWhiteSpace: Boolean = false
private var treatEmptyValuesAsNulls: Boolean = false
private var parserLib: String = ParserLibs.DEFAULT
private var charset: String = TextFile.DEFAULT_CHARSET.name()
private var inferSchema: Boolean = false
@@ -84,6 +85,11 @@ class CsvParser extends Serializable {
this
}

def withTreatEmptyValuesAsNulls(treatAsNull: Boolean): CsvParser = {
this.treatEmptyValuesAsNulls = treatAsNull
this
}

def withParserLib(parserLib: String): CsvParser = {
this.parserLib = parserLib
this
@@ -114,6 +120,7 @@
parserLib,
ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace,
treatEmptyValuesAsNulls,
schema,
inferSchema)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
@@ -132,6 +139,7 @@
parserLib,
ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace,
treatEmptyValuesAsNulls,
schema,
inferSchema)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
4 changes: 3 additions & 1 deletion src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -43,6 +43,7 @@ case class CsvRelation protected[spark] (
parserLib: String,
ignoreLeadingWhiteSpace: Boolean,
ignoreTrailingWhiteSpace: Boolean,
treatEmptyValuesAsNulls: Boolean,
userSchema: StructType = null,
inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with InsertableRelation {
@@ -113,7 +114,8 @@ case class CsvRelation protected[spark] (
index = 0
while (index < schemaFields.length) {
val field = schemaFields(index)
rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable)
rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
treatEmptyValuesAsNulls)
index = index + 1
}
Some(Row.fromSeq(rowArray))
9 changes: 9 additions & 0 deletions src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -112,6 +112,14 @@ class DefaultSource
} else {
throw new Exception("Ignore white space flag can be true or false")
}
val treatEmptyValuesAsNulls = parameters.getOrElse("treatEmptyValuesAsNulls", "false")
val treatEmptyValuesAsNullsFlag = if (treatEmptyValuesAsNulls == "false") {
false
} else if (treatEmptyValuesAsNulls == "true") {
true
} else {
throw new Exception("Treat empty values as null flag can be true or false")
}

val charset = parameters.getOrElse("charset", TextFile.DEFAULT_CHARSET.name())
// TODO validate charset?
@@ -137,6 +145,7 @@
parserLib,
ignoreLeadingWhiteSpaceFlag,
ignoreTrailingWhiteSpaceFlag,
treatEmptyValuesAsNullsFlag,
schema,
inferSchemaFlag)(sqlContext)
}
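The same flag can also be supplied as a data source option, which `DefaultSource` parses above; a sketch assuming Spark 1.4+'s `DataFrameReader` API (the path is illustrative):

// Any value other than "true" or "false" for treatEmptyValuesAsNulls throws, per the parsing above.
val df = sqlContext.read
  .format("com.databricks.spark.csv")
  .option("header", "true")
  .option("treatEmptyValuesAsNulls", "true")
  .load("/tmp/null-numbers.csv")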
8 changes: 6 additions & 2 deletions src/main/scala/com/databricks/spark/csv/package.scala
@@ -52,6 +52,7 @@ package object csv {
parserLib = parserLib,
ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
treatEmptyValuesAsNulls = false,
inferCsvSchema = inferSchema)(sqlContext)
sqlContext.baseRelationToDataFrame(csvRelation)
}
@@ -76,6 +77,7 @@ package object csv {
parserLib = parserLib,
ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
treatEmptyValuesAsNulls = false,
inferCsvSchema = inferSchema)(sqlContext)
sqlContext.baseRelationToDataFrame(csvRelation)
}
@@ -116,11 +118,13 @@ package object csv {
case None => None
}

val nullValue = parameters.getOrElse("nullValue", "null")

val csvFormatBase = CSVFormat.DEFAULT
.withDelimiter(delimiterChar)
.withEscape(escapeChar)
.withSkipHeaderRecord(false)
.withNullString("null")
.withNullString(nullValue)

val csvFormat = quoteChar match {
case Some(c) => csvFormatBase.withQuote(c)
@@ -139,7 +143,7 @@
.withDelimiter(delimiterChar)
.withEscape(escapeChar)
.withSkipHeaderRecord(false)
.withNullString("null")
.withNullString(nullValue)

val csvFormat = quoteChar match {
case Some(c) => csvFormatBase.withQuote(c)
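On the write path, the configurable `nullValue` now flows into Commons CSV's `withNullString`; a rough illustration of what that does to a single record (values are illustrative, not from this commit):

import org.apache.commons.csv.CSVFormat

CSVFormat.DEFAULT.withNullString("null").format("bob", null)  // roughly "bob,null" — the old hardcoded behavior
CSVFormat.DEFAULT.withNullString("").format("bob", null)      // roughly "bob,"    — nulls become empty fields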
8 changes: 6 additions & 2 deletions src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -35,8 +35,12 @@ object TypeCast {
* @param datum string value
* @param castType SparkSQL type
*/
private[csv] def castTo(datum: String, castType: DataType, nullable: Boolean = true): Any = {
if (datum == "" && nullable && !castType.isInstanceOf[StringType]){
private[csv] def castTo(
datum: String,
castType: DataType,
nullable: Boolean = true,
treatEmptyValuesAsNulls: Boolean = false): Any = {
if (datum == "" && nullable && (!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls)){
null
} else {
castType match {
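To illustrate the new `castTo` semantics (a sketch only: `castTo` is `private[csv]`, so these calls assume code inside the `com.databricks.spark.csv` packages, e.g. a unit test):

import org.apache.spark.sql.types.{IntegerType, StringType}
import com.databricks.spark.csv.util.TypeCast

TypeCast.castTo("", IntegerType, nullable = true)                                 // null, as before
TypeCast.castTo("", StringType, nullable = true)                                  // "" — previous behavior, still the default
TypeCast.castTo("", StringType, nullable = true, treatEmptyValuesAsNulls = true)  // null — new opt-in behavior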
24 changes: 24 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -167,6 +167,30 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt"))
}

test("DSL test roundtrip nulls") {
// Create temp directory
TestUtils.deleteRecursively(new File(tempEmptyDir))
new File(tempEmptyDir).mkdirs()
val copyFilePath = tempEmptyDir + "null-numbers.csv"
val agesSchema = StructType(List(StructField("name", StringType, true),
StructField("age", IntegerType, true)))

val agesRows = Seq(Row("alice", 35), Row("bob", null), Row(null, 24))
val agesRdd = sqlContext.sparkContext.parallelize(agesRows)
val agesDf = sqlContext.createDataFrame(agesRdd, agesSchema)

agesDf.saveAsCsvFile(copyFilePath, Map("header" -> "true", "nullValue" -> ""))

val agesCopy = new CsvParser()
.withSchema(agesSchema)
.withUseHeader(true)
.withTreatEmptyValuesAsNulls(true)
.withParserLib(parserLib)
.csvFile(sqlContext, copyFilePath)

assert(agesCopy.count == agesRows.size)
assert(agesCopy.collect.toSet == agesRows.toSet)
}

test("DSL test with alternative delimiter and quote") {
val results = new CsvParser()
