diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 7efeab528c..2156806d21 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -350,28 +350,47 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { dir => withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { import testImplicits._ - val path = dir.getCanonicalPath - sqlContext.range(10).coalesce(1).write.orc(path) + + // For field "a", the first column has odds integers. This is to check the filtered count + // when `isNull` is performed. For Field "b", `isNotNull` of ORC file filters rows + // only when all the values are null (maybe this works differently when the data + // or query is complicated). So, simply here a column only having `null` is added. + val data = (0 until 10).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + val nullValue: Option[String] = None + (maybeInt, nullValue) + } + createDataFrame(data).toDF("a", "b").write.orc(path) val df = sqlContext.read.orc(path) - def checkPredicate(pred: Column, answer: Seq[Long]): Unit = { - checkAnswer(df.where(pred), answer.map(Row(_))) + def checkPredicate(pred: Column, answer: Seq[Row]): Unit = { + val sourceDf = stripSparkFilter(df.where(pred)) + val data = sourceDf.collect().toSet + val expectedData = answer.toSet + + // When a filter is pushed to ORC, ORC can apply it to rows. So, we can check + // the number of rows returned from the ORC to make sure our filter pushdown work. + // A tricky part is, ORC does not process filter rows fully but return some possible + // results. So, this checks if the number of result is less than the original count + // of data, and then checks if it contains the expected data. + val isOrcFiltered = sourceDf.count < 10 && expectedData.subsetOf(data) + assert(isOrcFiltered) } - checkPredicate('id === 5, Seq(5L)) - checkPredicate('id <=> 5, Seq(5L)) - checkPredicate('id < 5, 0L to 4L) - checkPredicate('id <= 5, 0L to 5L) - checkPredicate('id > 5, 6L to 9L) - checkPredicate('id >= 5, 5L to 9L) - checkPredicate('id.isNull, Seq.empty[Long]) - checkPredicate('id.isNotNull, 0L to 9L) - checkPredicate('id.isin(1L, 3L, 5L), Seq(1L, 3L, 5L)) - checkPredicate('id > 0 && 'id < 3, 1L to 2L) - checkPredicate('id < 1 || 'id > 8, Seq(0L, 9L)) - checkPredicate(!('id > 3), 0L to 3L) - checkPredicate(!('id > 0 && 'id < 3), Seq(0L) ++ (3L to 9L)) + checkPredicate('a === 5, List(5).map(Row(_, null))) + checkPredicate('a <=> 5, List(5).map(Row(_, null))) + checkPredicate('a < 5, List(1, 3).map(Row(_, null))) + checkPredicate('a <= 5, List(1, 3, 5).map(Row(_, null))) + checkPredicate('a > 5, List(7, 9).map(Row(_, null))) + checkPredicate('a >= 5, List(5, 7, 9).map(Row(_, null))) + checkPredicate('a.isNull, List(null).map(Row(_, null))) + checkPredicate('b.isNotNull, List()) + checkPredicate('a.isin(3, 5, 7), List(3, 5, 7).map(Row(_, null))) + checkPredicate('a > 0 && 'a < 3, List(1).map(Row(_, null))) + checkPredicate('a < 1 || 'a > 8, List(9).map(Row(_, null))) + checkPredicate(!('a > 3), List(1, 3).map(Row(_, null))) + checkPredicate(!('a > 0 && 'a < 3), List(3, 5, 7, 9).map(Row(_, null))) } } }