
Commit

Correctly reading empty fields in as null rather than throwing exception (elastic#1816)

By default we intend to treat empty fields as nulls when they are read in through Spark SQL. However, we actually
turn them into None objects, which causes Spark SQL to blow up on Spark 2 and 3. This commit treats them
as nulls instead, which works for all versions of Spark we currently support.
Closes elastic#1635
masseyke authored Dec 13, 2021
1 parent ea9b62d commit 52c264f
Showing 6 changed files with 93 additions and 8 deletions.
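
Why None breaks: Spark SQL's row conversion accepts null for a nullable StringType column but rejects Scala's None, since None is not a valid external type for that schema. A minimal, hypothetical repro of the failure mode (not part of this commit; object and app names are made up):

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object NoneVsNullRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("none-vs-null").getOrCreate()
    val schema = StructType(Array(StructField("description", StringType, nullable = true)))

    // null is fine for a nullable string column...
    val ok = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row(null))), schema)
    ok.show()

    // ...but None blows up when the rows are evaluated:
    // "scala.None$ is not a valid external type for schema of string"
    val broken = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row(None))), schema)
    broken.show()

    spark.stop()
  }
}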
@@ -126,7 +126,7 @@ class ScalaValueReader extends AbstractValueReader with SettingsAware {
}
}

- def nullValue() = { None }
+ def nullValue() = { null }
def textValue(value: String, parser: Parser) = { checkNull (parseText, value, parser) }
protected def parseText(value:String, parser: Parser) = { value }

@@ -49,7 +49,7 @@ class ScalaExtendedBooleanValueReaderTest(jsonString: String, expected: Expected

def isNull: Matcher[AnyRef] = {
return new BaseMatcher[AnyRef] {
- override def matches(item: scala.Any): Boolean = item == None
+ override def matches(item: scala.Any): Boolean = item == null
override def describeTo(description: Description): Unit = description.appendText("null")
}
}
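
A matcher like this plugs into Hamcrest's assertThat; a hypothetical usage (the asserted value is assumed, not taken from this diff):

import org.hamcrest.MatcherAssert.assertThat

// Hypothetical check: an empty field read by the value reader should
// now match isNull, i.e. be null rather than Scala's None.
assertThat(new ScalaValueReader().nullValue(), isNull)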
@@ -26,8 +26,8 @@ class ScalaValueReaderTest extends BaseValueReaderTest {

override def createValueReader() = new ScalaValueReader()

- override def checkNull(result: Object): Unit = { assertEquals(None, result)}
- override def checkEmptyString(result: Object): Unit = { assertEquals(None, result)}
+ override def checkNull(result: Object): Unit = { assertEquals(null, result)}
+ override def checkEmptyString(result: Object): Unit = { assertEquals(null, result)}
override def checkInteger(result: Object): Unit = { assertEquals(Int.MaxValue, result)}
override def checkLong(result: Object): Unit = { assertEquals(Long.MaxValue, result)}
override def checkDouble(result: Object): Unit = { assertEquals(Double.MaxValue, result)}
@@ -23,7 +23,6 @@ import java.{lang => jl}
import java.sql.Timestamp
import java.{util => ju}
import java.util.concurrent.TimeUnit
-
import scala.collection.JavaConversions.propertiesAsScalaMap
import scala.collection.JavaConverters.asScalaBufferConverter
import scala.collection.JavaConverters.mapAsJavaMapConverter
@@ -68,6 +67,8 @@ import org.junit.runners.Parameterized.Parameters
import org.junit.runners.MethodSorters
import com.esotericsoftware.kryo.io.{Input => KryoInput}
import com.esotericsoftware.kryo.io.{Output => KryoOutput}
+import org.apache.spark.rdd.RDD
+
import javax.xml.bind.DatatypeConverter
import org.elasticsearch.hadoop.{EsHadoopIllegalArgumentException, EsHadoopIllegalStateException}
import org.apache.spark.sql.types.DoubleType
@@ -419,6 +420,33 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
val results = sqc.sql("SELECT name FROM datfile WHERE id >=1 AND id <=10")
//results.take(5).foreach(println)
}
+
+  @Test
+  def testEmptyStrings(): Unit = {
+    val data = Seq(("Java", "20000"), ("Python", ""), ("Scala", "3000"))
+    val rdd: RDD[Row] = sc.parallelize(data).map(row => Row(row._1, row._2))
+    val schema = StructType(Array(
+      StructField("language", StringType, true),
+      StructField("description", StringType, true)
+    ))
+    val inputDf = sqc.createDataFrame(rdd, schema)
+    inputDf.write
+      .format("org.elasticsearch.spark.sql")
+      .save("empty_strings_test")
+    val reader = sqc.read.format("org.elasticsearch.spark.sql")
+    val outputDf = reader.load("empty_strings_test")
+    assertEquals(data.size, outputDf.count)
+    val nullDescriptionsDf = outputDf.filter("language = 'Python'")
+    assertEquals(1, nullDescriptionsDf.count)
+    assertEquals(null, nullDescriptionsDf.first().getAs("description"))
+
+    val reader2 = sqc.read.format("org.elasticsearch.spark.sql").option("es.field.read.empty.as.null", "no")
+    val outputDf2 = reader2.load("empty_strings_test")
+    assertEquals(data.size, outputDf2.count)
+    val emptyDescriptionsDf = outputDf2.filter("language = 'Python'")
+    assertEquals(1, emptyDescriptionsDf.count)
+    assertEquals("", emptyDescriptionsDf.first().getAs("description"))
+  }

@Test
def test0WriteFieldNameWithPercentage() {
@@ -27,7 +27,6 @@ import java.nio.file.Paths
import java.sql.Timestamp
import java.{util => ju}
import java.util.concurrent.TimeUnit
-
import scala.collection.JavaConversions.propertiesAsScalaMap
import scala.collection.JavaConverters.asScalaBufferConverter
import scala.collection.JavaConverters.mapAsJavaMapConverter
@@ -86,6 +85,8 @@ import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters
import com.esotericsoftware.kryo.io.{Input => KryoInput}
import com.esotericsoftware.kryo.io.{Output => KryoOutput}
+import org.apache.spark.rdd.RDD
+
import javax.xml.bind.DatatypeConverter
import org.apache.spark.sql.SparkSession
import org.elasticsearch.hadoop.EsAssume
@@ -438,6 +439,33 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
val results = sqc.sql("SELECT name FROM datfile WHERE id >=1 AND id <=10")
//results.take(5).foreach(println)
}
+
+  @Test
+  def testEmptyStrings(): Unit = {
+    val data = Seq(("Java", "20000"), ("Python", ""), ("Scala", "3000"))
+    val rdd: RDD[Row] = sc.parallelize(data).map(row => Row(row._1, row._2))
+    val schema = StructType(Array(
+      StructField("language", StringType, true),
+      StructField("description", StringType, true)
+    ))
+    val inputDf = sqc.createDataFrame(rdd, schema)
+    inputDf.write
+      .format("org.elasticsearch.spark.sql")
+      .save("empty_strings_test")
+    val reader = sqc.read.format("org.elasticsearch.spark.sql")
+    val outputDf = reader.load("empty_strings_test")
+    assertEquals(data.size, outputDf.count)
+    val nullDescriptionsDf = outputDf.filter(row => row.getAs("description") == null)
+    assertEquals(1, nullDescriptionsDf.count)
+
+    val reader2 = sqc.read.format("org.elasticsearch.spark.sql").option("es.field.read.empty.as.null", "no")
+    val outputDf2 = reader2.load("empty_strings_test")
+    assertEquals(data.size, outputDf2.count)
+    val nullDescriptionsDf2 = outputDf2.filter(row => row.getAs("description") == null)
+    assertEquals(0, nullDescriptionsDf2.count)
+    val emptyDescriptionsDf = outputDf2.filter(row => row.getAs("description") == "")
+    assertEquals(1, emptyDescriptionsDf.count)
+  }

@Test
def test0WriteFieldNameWithPercentage() {
@@ -27,7 +27,6 @@ import java.nio.file.Paths
import java.sql.Timestamp
import java.{util => ju}
import java.util.concurrent.TimeUnit
-
import scala.collection.JavaConversions.propertiesAsScalaMap
import scala.collection.JavaConverters.asScalaBufferConverter
import scala.collection.JavaConverters.mapAsJavaMapConverter
@@ -86,6 +85,8 @@ import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters
import com.esotericsoftware.kryo.io.{Input => KryoInput}
import com.esotericsoftware.kryo.io.{Output => KryoOutput}
+import org.apache.spark.rdd.RDD
+
import javax.xml.bind.DatatypeConverter
import org.apache.spark.sql.SparkSession
import org.elasticsearch.hadoop.EsAssume
@@ -98,6 +99,7 @@ import org.junit.Assert._
import org.junit.ClassRule

object AbstractScalaEsScalaSparkSQL {
+
@transient val conf = new SparkConf()
.setAll(propertiesAsScalaMap(TestSettings.TESTING_PROPS))
.setAppName("estest")
@@ -438,7 +440,34 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
val results = sqc.sql("SELECT name FROM datfile WHERE id >=1 AND id <=10")
//results.take(5).foreach(println)
}

+
+  @Test
+  def testEmptyStrings(): Unit = {
+    val data = Seq(("Java", "20000"), ("Python", ""), ("Scala", "3000"))
+    val rdd: RDD[Row] = sc.parallelize(data).map(row => Row(row._1, row._2))
+    val schema = StructType(Array(
+      StructField("language", StringType, true),
+      StructField("description", StringType, true)
+    ))
+    val inputDf = sqc.createDataFrame(rdd, schema)
+    inputDf.write
+      .format("org.elasticsearch.spark.sql")
+      .save("empty_strings_test")
+    val reader = sqc.read.format("org.elasticsearch.spark.sql")
+    val outputDf = reader.load("empty_strings_test")
+    assertEquals(data.size, outputDf.count)
+    val nullDescriptionsDf = outputDf.filter(row => row.getAs("description") == null)
+    assertEquals(1, nullDescriptionsDf.count)
+
+    val reader2 = sqc.read.format("org.elasticsearch.spark.sql").option("es.field.read.empty.as.null", "no")
+    val outputDf2 = reader2.load("empty_strings_test")
+    assertEquals(data.size, outputDf2.count)
+    val nullDescriptionsDf2 = outputDf2.filter(row => row.getAs("description") == null)
+    assertEquals(0, nullDescriptionsDf2.count)
+    val emptyDescriptionsDf = outputDf2.filter(row => row.getAs("description") == "")
+    assertEquals(1, emptyDescriptionsDf.count)
+  }

@Test
def test0WriteFieldNameWithPercentage() {
val index = wrapIndex("spark-test-scala-sql-field-with-percentage")
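
A closing note on the toggle these tests exercise: es.field.read.empty.as.null defaults to yes, so empty fields read back as null unless it is disabled. Besides the per-reader option(...) shown above, the key can also be supplied through the Spark configuration; a sketch (app name assumed):

import org.apache.spark.SparkConf

// Sketch: disable the empty-as-null behavior connector-wide, so empty
// string fields are returned as "" instead of null.
val conf = new SparkConf()
  .setAppName("empty-as-null-demo")
  .set("es.field.read.empty.as.null", "no")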
